Skip to content

Commit 53b9866

Browse files
divyashreepathihalli authored and chiruu12 committed
Fix Remat error when called with a model (keras-team#21094)
* add print
* fix remat issue
* simplify code
* enable traceback filtering and update the function sig
* add a wrapper for activations
* change to except
* add layer call decorator
* fix remat call
1 parent eafdcc0 commit 53b9866

File tree

3 files changed

+33
-17
lines changed

3 files changed

+33
-17
lines changed

keras/src/layers/layer.py

Lines changed: 19 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
"""
1818

1919
import collections
20+
import functools
2021
import inspect
2122
import math
2223
import warnings
@@ -1053,11 +1054,13 @@ def stateless_call(
10531054
if self._remat_mode is not None:
10541055
outputs = self.rematerialized_call(
10551056
self.quantized_call, *args, **kwargs
1056-
)
1057+
)(*args, **kwargs)
10571058
else:
10581059
outputs = self.quantized_call(*args, **kwargs)
10591060
elif self._remat_mode is not None:
1060-
outputs = self.rematerialized_call(self.call, *args, **kwargs)
1061+
outputs = self.rematerialized_call(self.call, *args, **kwargs)(
1062+
*args, **kwargs
1063+
)
10611064
else:
10621065
outputs = self.call(*args, **kwargs)
10631066
if return_losses:
@@ -1611,13 +1614,13 @@ def compute_size(x):
16111614

16121615
# Full rematerialization
16131616
if self._remat_mode.mode == "full":
1614-
return remat.remat(layer_call)(*args, **kwargs)
1617+
return remat.remat(layer_call)
16151618

16161619
# Apply rematerialization to specific layers
16171620
elif self._remat_mode.mode == "list_of_layers" and (
16181621
self.name in self._remat_mode.layer_names
16191622
):
1620-
return remat.remat(layer_call)(*args, **kwargs)
1623+
return remat.remat(layer_call)
16211624

16221625
# Apply rematerialization based on output size threshold
16231626
elif self._remat_mode.mode == "larger_than":
@@ -1629,20 +1632,24 @@ def compute_size(x):
16291632
output_size
16301633
and output_size > self._remat_mode.output_size_threshold
16311634
):
1632-
return remat.remat(layer_call)(*args, **kwargs)
1635+
return remat.remat(layer_call)
16331636
elif self._remat_mode.mode == "activations":
16341637
has_activation = (
16351638
hasattr(self, "activation") and self.activation is not None
16361639
)
16371640
if has_activation:
1638-
not_rematted_activation = self.activation
1639-
try:
1640-
self.activation = remat.remat(not_rematted_activation)
1641-
return layer_call(*args, **kwargs)
1642-
finally:
1643-
self.activation = not_rematted_activation
16441641

1645-
return layer_call(*args, **kwargs)
1642+
@functools.wraps(layer_call)
1643+
def rematerialized_activation_call_wrapper(*args, **kwargs):
1644+
original_activation = self.activation
1645+
self.activation = remat.remat(original_activation)
1646+
try:
1647+
return layer_call(*args, **kwargs)
1648+
finally:
1649+
self.activation = original_activation
1650+
1651+
return rematerialized_activation_call_wrapper
1652+
return layer_call
16461653

16471654

16481655
def is_backend_tensor_or_symbolic(x, allow_none=False):

keras/src/layers/layer_test.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
from keras.src.backend.common import global_state
1717
from keras.src.backend.common.remat import RematScope
1818
from keras.src.models import Model
19+
from keras.src.utils import traceback_utils
1920

2021

2122
class MockRemat:
@@ -237,7 +238,7 @@ def test_functional_model_with_remat(self):
237238
self.skipTest(
238239
"remat is not supported in openvino and numpy backends."
239240
)
240-
# traceback_utils.enable_traceback_filtering()
241+
traceback_utils.enable_traceback_filtering()
241242
mock_remat = MockRemat()
242243
with mock.patch(
243244
"keras.src.backend.common.remat.remat", wraps=mock_remat

keras/src/ops/operation.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -37,9 +37,15 @@ def __call__(self, *args, **kwargs):
3737
else:
3838
if getattr(self, "_remat_mode", None) is not None:
3939
if getattr(self, "quantization_mode", None) is not None:
40-
call_fn = self.rematerialized_call(self.quantized_call)
40+
call_fn = self.rematerialized_call(
41+
self.quantized_call,
42+
*args,
43+
**kwargs,
44+
)
4145
else:
42-
call_fn = self.rematerialized_call(self.call)
46+
call_fn = self.rematerialized_call(
47+
self.call, *args, **kwargs
48+
)
4349
else:
4450
if getattr(self, "quantization_mode", None) is not None:
4551
call_fn = self.quantized_call
@@ -58,9 +64,11 @@ def __call__(self, *args, **kwargs):
5864
if getattr(self, "quantization_mode", None) is not None:
5965
return self.rematerialized_call(
6066
self.quantized_call, *args, **kwargs
61-
)
67+
)(*args, **kwargs)
6268
else:
63-
return self.rematerialized_call(self.call, *args, **kwargs)
69+
return self.rematerialized_call(self.call, *args, **kwargs)(
70+
*args, **kwargs
71+
)
6472
else:
6573
if getattr(self, "quantization_mode", None) is not None:
6674
return self.quantized_call(*args, **kwargs)

0 commit comments

Comments
 (0)