Fix Remat error when called with a model #21094


Merged
10 commits merged on Mar 28, 2025
31 changes: 19 additions & 12 deletions keras/src/layers/layer.py
@@ -17,6 +17,7 @@
 """

 import collections
+import functools
 import inspect
 import math
 import warnings
@@ -1053,11 +1054,13 @@ def stateless_call(
                 if self._remat_mode is not None:
                     outputs = self.rematerialized_call(
                         self.quantized_call, *args, **kwargs
-                    )
+                    )(*args, **kwargs)
                 else:
                     outputs = self.quantized_call(*args, **kwargs)
             elif self._remat_mode is not None:
-                outputs = self.rematerialized_call(self.call, *args, **kwargs)
+                outputs = self.rematerialized_call(self.call, *args, **kwargs)(
+                    *args, **kwargs
+                )
             else:
                 outputs = self.call(*args, **kwargs)
         if return_losses:
@@ -1611,13 +1614,13 @@ def compute_size(x):

         # Full rematerialization
         if self._remat_mode.mode == "full":
-            return remat.remat(layer_call)(*args, **kwargs)
+            return remat.remat(layer_call)

         # Apply rematerialization to specific layers
         elif self._remat_mode.mode == "list_of_layers" and (
             self.name in self._remat_mode.layer_names
         ):
-            return remat.remat(layer_call)(*args, **kwargs)
+            return remat.remat(layer_call)

         # Apply rematerialization based on output size threshold
         elif self._remat_mode.mode == "larger_than":
@@ -1629,20 +1632,24 @@ def compute_size(x):
                 output_size
                 and output_size > self._remat_mode.output_size_threshold
             ):
-                return remat.remat(layer_call)(*args, **kwargs)
+                return remat.remat(layer_call)
         elif self._remat_mode.mode == "activations":
             has_activation = (
                 hasattr(self, "activation") and self.activation is not None
             )
             if has_activation:
-                not_rematted_activation = self.activation
-                try:
-                    self.activation = remat.remat(not_rematted_activation)
-                    return layer_call(*args, **kwargs)
-                finally:
-                    self.activation = not_rematted_activation
-
-        return layer_call(*args, **kwargs)
+
+                @functools.wraps(layer_call)
+                def rematerialized_activation_call_wrapper(*args, **kwargs):
+                    original_activation = self.activation
+                    self.activation = remat.remat(original_activation)
+                    try:
+                        return layer_call(*args, **kwargs)
+                    finally:
+                        self.activation = original_activation
+
+                return rematerialized_activation_call_wrapper
+        return layer_call


 def is_backend_tensor_or_symbolic(x, allow_none=False):
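
The layer.py change shifts rematerialized_call from executing the layer to returning a (possibly remat-wrapped) callable, which the call sites in stateless_call then invoke with the original arguments. A minimal standalone sketch of that calling convention follows; fake_remat and TinyLayer are illustrative stand-ins, not Keras internals.

import functools


def fake_remat(fn):
    """Stand-in for a backend remat transform: wrap the function, don't call it."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # A real remat would drop intermediate activations here and recompute
        # them during the backward pass; this stub just forwards the call.
        return fn(*args, **kwargs)

    return wrapper


class TinyLayer:
    def call(self, x):
        return x * 2

    def rematerialized_call(self, layer_call, *args, **kwargs):
        # Return the wrapped callable instead of its output.
        return fake_remat(layer_call)

    def __call__(self, *args, **kwargs):
        call_fn = self.rematerialized_call(self.call, *args, **kwargs)
        # The caller can decorate or inspect call_fn (e.g. for traceback
        # filtering) before invoking it exactly once.
        return call_fn(*args, **kwargs)


print(TinyLayer()(21))  # prints 42
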
3 changes: 2 additions & 1 deletion keras/src/layers/layer_test.py
@@ -16,6 +16,7 @@
 from keras.src.backend.common import global_state
 from keras.src.backend.common.remat import RematScope
 from keras.src.models import Model
+from keras.src.utils import traceback_utils


 class MockRemat:
@@ -237,7 +238,7 @@ def test_functional_model_with_remat(self):
             self.skipTest(
                 "remat is not supported in openvino and numpy backends."
             )
-        # traceback_utils.enable_traceback_filtering()
+        traceback_utils.enable_traceback_filtering()
         mock_remat = MockRemat()
         with mock.patch(
             "keras.src.backend.common.remat.remat", wraps=mock_remat
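
The test now enables traceback filtering, exercising the Operation.__call__ branch that treats the result of rematerialized_call as a callable. It patches the backend remat with the MockRemat helper defined earlier in the test file; a sketch of such a call-recording wrapper is below. The attribute name rematted_functions is an assumption for illustration, not necessarily the real field.

from unittest import mock


class MockRemat:
    """Call-recording stand-in for the backend remat transform."""

    def __init__(self):
        self.rematted_functions = []  # functions that were handed to remat

    def __call__(self, fn):
        self.rematted_functions.append(fn)
        return fn  # pass through unchanged; only record that remat was applied


# Hypothetical usage inside a test body (requires keras to be importable):
mock_remat = MockRemat()
with mock.patch(
    "keras.src.backend.common.remat.remat", wraps=mock_remat
):
    pass  # build and call the model here, then assert on mock_remat
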
16 changes: 12 additions & 4 deletions keras/src/ops/operation.py
@@ -37,9 +37,15 @@ def __call__(self, *args, **kwargs):
             else:
                 if getattr(self, "_remat_mode", None) is not None:
                     if getattr(self, "quantization_mode", None) is not None:
-                        call_fn = self.rematerialized_call(self.quantized_call)
+                        call_fn = self.rematerialized_call(
+                            self.quantized_call,
+                            *args,
+                            **kwargs,
+                        )
                     else:
-                        call_fn = self.rematerialized_call(self.call)
+                        call_fn = self.rematerialized_call(
+                            self.call, *args, **kwargs
+                        )
                 else:
                     if getattr(self, "quantization_mode", None) is not None:
                         call_fn = self.quantized_call
@@ -58,9 +64,11 @@ def __call__(self, *args, **kwargs):
                 if getattr(self, "quantization_mode", None) is not None:
                     return self.rematerialized_call(
                         self.quantized_call, *args, **kwargs
-                    )
+                    )(*args, **kwargs)
                 else:
-                    return self.rematerialized_call(self.call, *args, **kwargs)
+                    return self.rematerialized_call(self.call, *args, **kwargs)(
+                        *args, **kwargs
+                    )
             else:
                 if getattr(self, "quantization_mode", None) is not None:
                     return self.quantized_call(*args, **kwargs)
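
Taken together with the operation.py change, the fix targets the scenario in the PR title: calling a functional model whose layers were built with rematerialization enabled. A sketch of that scenario, assuming the public keras.RematScope API (Keras >= 3.9):

import numpy as np

import keras
from keras import layers

# Layers created inside the scope get a remat mode attached to them.
with keras.RematScope(mode="full"):
    inputs = keras.Input(shape=(4,))
    x = layers.Dense(8, activation="relu")(inputs)
    outputs = layers.Dense(1)(x)
    model = keras.Model(inputs, outputs)

# Before this fix, calling the model could fail because rematerialized_call
# executed the layer itself instead of returning a callable for
# Operation.__call__ / Layer.stateless_call to invoke.
print(model(np.ones((2, 4))).shape)  # (2, 1)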