Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tf_keras/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1894,7 +1894,7 @@ def on_train_begin(self, logs=None):
# TrainingState is used to manage the training state needed for
# failure-recovery of a worker in training.

if self.model._distribution_strategy and not isinstance(
if self.model.distribute_strategy and not isinstance(
self.model.distribute_strategy, self._supported_strategies
):
Comment on lines +1897 to 1899

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

While this change correctly uses the public distribute_strategy property and fixes the AttributeError, it introduces a potential issue. The property self.model.distribute_strategy falls back to tf.distribute.get_strategy(), which returns a DefaultDistributionStrategy instance if no other strategy is active. This means the property is never None, and the if condition will always be evaluated.

Since DefaultDistributionStrategy is likely not in self._supported_strategies, this will cause a NotImplementedError to be raised for users who are not using any distribution strategy, which is not the intended behavior.

The logic should be adjusted to explicitly allow the DefaultDistributionStrategy to pass through without raising an error. This ensures the callback only validates strategies when a non-default one is actually in use.

        if not isinstance(
            self.model.distribute_strategy,
            (tf.distribute.DefaultDistributionStrategy, *self._supported_strategies),
        ):

raise NotImplementedError(
Expand Down