From c28cf75fbc1f9831ea8ab38905023e936f623261 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Mon, 22 Jul 2024 10:28:17 +0800 Subject: [PATCH] Minor updates and add tests --- keras/src/backend/__init__.py | 2 + keras/src/backend/common/symbolic_scope.py | 2 + .../src/backend/common/symbolic_scope_test.py | 26 ++++++++++ keras/src/backend/numpy/trainer.py | 14 +++++- keras/src/trainers/compile_utils.py | 15 ++---- keras/src/trainers/trainer.py | 8 ++-- keras/src/trainers/trainer_test.py | 47 +++++++++++++++++++ 7 files changed, 97 insertions(+), 17 deletions(-) create mode 100644 keras/src/backend/common/symbolic_scope_test.py diff --git a/keras/src/backend/__init__.py b/keras/src/backend/__init__.py index 5c7fa2235207..794fe3ca3645 100644 --- a/keras/src/backend/__init__.py +++ b/keras/src/backend/__init__.py @@ -14,6 +14,8 @@ from keras.src.backend.common.stateless_scope import StatelessScope from keras.src.backend.common.stateless_scope import get_stateless_scope from keras.src.backend.common.stateless_scope import in_stateless_scope +from keras.src.backend.common.symbolic_scope import SymbolicScope +from keras.src.backend.common.symbolic_scope import in_symbolic_scope from keras.src.backend.common.variables import AutocastScope from keras.src.backend.common.variables import get_autocast_scope from keras.src.backend.common.variables import is_float_dtype diff --git a/keras/src/backend/common/symbolic_scope.py b/keras/src/backend/common/symbolic_scope.py index 780032d57282..15cd7a5ee059 100644 --- a/keras/src/backend/common/symbolic_scope.py +++ b/keras/src/backend/common/symbolic_scope.py @@ -4,6 +4,8 @@ @keras_export("keras.SymbolicScope") class SymbolicScope: + """Scope to indicate the symbolic stage.""" + def __enter__(self): self.original_scope = get_symbolic_scope() global_state.set_global_attribute("symbolic_scope", self) diff --git a/keras/src/backend/common/symbolic_scope_test.py b/keras/src/backend/common/symbolic_scope_test.py new file mode 100644 index 000000000000..092dcfe0748c --- /dev/null +++ b/keras/src/backend/common/symbolic_scope_test.py @@ -0,0 +1,26 @@ +import numpy as np + +from keras.src import ops +from keras.src import testing +from keras.src.backend.common.symbolic_scope import SymbolicScope +from keras.src.backend.common.symbolic_scope import in_symbolic_scope + + +class TestSymbolicScope(testing.TestCase): + def test_basic_flow(self): + + # Define a function that behaves differently according to + # `in_symbolic_scope`. + def compute_loss(y, y_pred): + if in_symbolic_scope(): + return ops.zeros_like(y) + return ops.add(y, y_pred) + + y = ops.ones(shape=(2,)) + y_pred = ops.ones(shape=(2,)) + with SymbolicScope(): + loss = compute_loss(y, y_pred) + self.assertAllClose(loss, np.zeros((2,))) + + loss = compute_loss(y, y_pred) + self.assertAllClose(loss, 2 * np.ones((2,))) diff --git a/keras/src/backend/numpy/trainer.py b/keras/src/backend/numpy/trainer.py index 6d40982be43e..12c3aad56b65 100644 --- a/keras/src/backend/numpy/trainer.py +++ b/keras/src/backend/numpy/trainer.py @@ -97,7 +97,10 @@ def _symbolic_build(self, data_batch): self._compile_metrics is not None and not self._compile_metrics.built ) - if model_unbuilt or compile_metrics_unbuilt: + compile_loss_unbuilt = ( + self._compile_loss is not None and not self._compile_loss.built + ) + if model_unbuilt or compile_metrics_unbuilt or compile_loss_unbuilt: # Create symbolic tensors matching an input batch. def to_symbolic_input(v): @@ -133,6 +136,15 @@ def to_symbolic_input(v): y_pred, sample_weight=sample_weight, ) + if compile_loss_unbuilt: + # Build `CompileLoss` state with `backend.compute_output_spec`. + backend.compute_output_spec( + self._compute_loss, + x, + y, + y_pred, + sample_weight=sample_weight, + ) self._post_build() def fit( diff --git a/keras/src/trainers/compile_utils.py b/keras/src/trainers/compile_utils.py index 51b34742fc74..114925e669df 100644 --- a/keras/src/trainers/compile_utils.py +++ b/keras/src/trainers/compile_utils.py @@ -445,16 +445,10 @@ def __init__( @property def metrics(self): - if not self.built: - return [] return self._metrics @property def variables(self): - # Avoiding relying on implicit tracking since - # CompileLoss may be instantiated or built in a no tracking scope. - if not self.built: - return [] vars = [] for m in self.metrics: vars.extend(m.variables) @@ -639,12 +633,9 @@ def call(self, y_true, y_pred, sample_weight=None): sample_weight = [sample_weight[0] for _ in range(len(y_true))] else: sample_weight = [None for _ in y_true] - if len(self.metrics) == 0: - # This means that the model has a single output. We need to add a - # dummy `None` for the following `zip` to function correctly. - metrics = [None] - else: - metrics = self.metrics + + # We need to add a dummy `None` if the model has only a single output. + metrics = [None] if len(self.metrics) == 0 else self.metrics # Iterate all losses in flat form. loss_values = [] diff --git a/keras/src/trainers/trainer.py b/keras/src/trainers/trainer.py index 22917f616449..c2f6ede112d7 100644 --- a/keras/src/trainers/trainer.py +++ b/keras/src/trainers/trainer.py @@ -328,9 +328,10 @@ def metrics(self): loss = self._compile_loss(y, y_pred, sample_weight) if loss is not None: losses.append(loss) + + # If in symbolic scope, skip `self.losses` to ensure we don't access + # any variables. Otherwise, it might break. if not in_symbolic_scope(): - # If in symbolic scope, skip `self.losses` to ensure we don't access - # any variables. for loss in self.losses: losses.append(ops.sum(ops.cast(loss, dtype=backend.floatx()))) if backend.backend() != "jax" and len(losses) == 0: @@ -1042,7 +1043,7 @@ def to_symbolic_input(v): # Build all model state with `backend.compute_output_spec`. try: - y_pred = backend.compute_output_spec(self, x, training=False) + y_pred = backend.compute_output_spec(self, x) except Exception as e: raise RuntimeError( "Unable to automatically build the model. " @@ -1072,7 +1073,6 @@ def to_symbolic_input(v): y, y_pred, sample_weight=sample_weight, - training=False, ) if backend.backend() == "torch": if original_training: diff --git a/keras/src/trainers/trainer_test.py b/keras/src/trainers/trainer_test.py index e5bc3cbdc8ff..d7c320c39bfe 100644 --- a/keras/src/trainers/trainer_test.py +++ b/keras/src/trainers/trainer_test.py @@ -1617,6 +1617,53 @@ def test_loss_weights(self): atol=1e-3, ) + def test_symbolic_build(self): + class ExampleModelWithTrainingArgs(Trainer, layers.Layer): + def __init__(self, units): + layers.Layer.__init__(self) + Trainer.__init__(self) + self.dense = layers.Dense( + units, + use_bias=False, + kernel_initializer=initializers.Ones(), + ) + self.bn = layers.BatchNormalization(axis=-1) + + def build(self, input_shape): + self.dense.build(input_shape) + input_shape = self.dense.compute_output_shape(input_shape) + self.bn.build(input_shape) + + def call(self, x, training=None): + outputs = self.bn(self.dense(x), training=training) + return [outputs, outputs] + + model = ExampleModelWithTrainingArgs(units=3) + model.compile( + optimizer=optimizers.SGD(), + loss=[losses.MeanSquaredError(), losses.MeanSquaredError()], + metrics=[metrics.MeanSquaredError(), metrics.MeanSquaredError()], + ) + x = np.ones((4, 4)) + y = np.zeros((4, 3)) + model.build(x.shape) + ref_weights = model.get_weights() + model._symbolic_build(data_batch=(x, (y, y))) + weights = model.get_weights() + + # Ensure weights are intact + self.assertEqual(len(weights), len(ref_weights)) + for w, ref_w in zip(weights, ref_weights): + self.assertAllClose(w, ref_w) + + # Ensure `built` + self.assertTrue(model.built) + self.assertTrue(model._compile_metrics.built) + self.assertTrue(model._compile_loss.built) + + # Ensure the len of CompileLoss's metrics (loss trackers) + self.assertLen(model._compile_loss.metrics, 2) + class TrainerDistributeTest(testing.TestCase): @pytest.mark.skipif(