Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bring back loss information for multiple outputs #20023

Merged
merged 7 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Fix TorchWrapper of training args
  • Loading branch information
james77777778 committed Jul 22, 2024
commit 603525c2ab05129f805184dd5e8f5b9a9628eb7d
9 changes: 2 additions & 7 deletions keras/src/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,9 +1013,6 @@ def _symbolic_build(self, iterator=None, data_batch=None):
self.optimizer is not None and not self.optimizer.built
)
if model_unbuilt or compile_metrics_unbuilt or compile_loss_unbuilt:
if backend.backend() == "torch":
original_training = self.training
self.eval()
# Create symbolic tensors matching an input batch.

def to_symbolic_input(v):
Expand All @@ -1038,7 +1035,7 @@ def to_symbolic_input(v):

# Build all model state with `backend.compute_output_spec`.
try:
y_pred = backend.compute_output_spec(self, x)
y_pred = backend.compute_output_spec(self, x, training=False)
except Exception as e:
raise RuntimeError(
"Unable to automatically build the model. "
Expand Down Expand Up @@ -1068,10 +1065,8 @@ def to_symbolic_input(v):
y,
y_pred,
sample_weight=sample_weight,
training=False,
)
if backend.backend() == "torch":
if original_training:
self.train()
if optimizer_unbuilt:
# Build optimizer
self.optimizer.build(self.trainable_variables)
Expand Down
6 changes: 5 additions & 1 deletion keras/src/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,11 @@ def _track_module_parameters(self):
self._track_variable(variable)
self.built = True

def call(self, *args, **kwargs):
def call(self, *args, training=None, **kwargs):
if training is False:
self.eval()
else:
self.train()
return self.module(*args, **kwargs)

def save_own_variables(self, store):
Expand Down
50 changes: 47 additions & 3 deletions keras/src/utils/torch_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ def __init__(
self.torch_wrappers.append(TorchModuleWrapper(torch_model))
self.fc = layers.Dense(1)

def call(self, x):
def call(self, x, training=None):
for wrapper in self.torch_wrappers:
x = wrapper(x)
x = wrapper(x, training=training)
return self.fc(x)

def get_config(self):
Expand All @@ -49,7 +49,7 @@ def __init__(self, *args, **kwargs):
self.fc2 = torch.nn.Linear(4, 4)
self.fc3 = layers.Dense(2)

def call(self, x):
def call(self, x, training=None):
return self.fc3(self.fc2(self.bn1(self.fc1(x))))


Expand Down Expand Up @@ -82,6 +82,50 @@ def test_basic_usage(self, use_batch_norm, num_torch_layers):
model.compile(optimizer="sgd", loss="mse")
model.fit(np.random.random((3, 2)), np.random.random((3, 1)))

@parameterized.named_parameters(
(
"explicit_torch_wrapper",
Classifier,
{"use_batch_norm": True, "num_torch_layers": 1},
),
("implicit_torch_wrapper", ClassifierWithNoSpecialCasing, {}),
)
def test_training_args(self, cls, kwargs):
model = cls(**kwargs)
model(np.random.random((3, 2)), training=False) # Eager call to build
ref_weights = model.get_weights()
ref_running_mean = backend.convert_to_numpy(
model.torch_wrappers[0].module[-1].running_mean
if cls is Classifier
else model.bn1.module.running_mean
)

# Test training=False doesn't affect model weights
model(np.random.random((3, 2)), training=False)
weights = model.get_weights()
for w, ref_w in zip(weights, ref_weights):
self.assertAllClose(w, ref_w)

# Test training=None affects BN's stats
model.set_weights(ref_weights) # Restore previous weights
model(np.random.random((3, 2)))
running_mean = backend.convert_to_numpy(
model.torch_wrappers[0].module[-1].running_mean
if cls is Classifier
else model.bn1.module.running_mean
)
self.assertNotAllClose(running_mean, ref_running_mean)

# Test training=True affects BN's stats
model.set_weights(ref_weights) # Restore previous weights
model(np.random.random((3, 2)), training=True)
running_mean = backend.convert_to_numpy(
model.torch_wrappers[0].module[-1].running_mean
if cls is Classifier
else model.bn1.module.running_mean
)
self.assertNotAllClose(running_mean, ref_running_mean)

def test_module_autowrapping(self):
model = ClassifierWithNoSpecialCasing()
self.assertIsInstance(model.fc1, TorchModuleWrapper)
Expand Down