
Commit 11be99a

Add custom_gradient for the numpy backend (#19849)
Parent: a4e8554

File tree: 4 files changed (+22, -9 lines)


keras/src/backend/numpy/core.py (20 additions, 4 deletions)

@@ -1,3 +1,5 @@
+import warnings
+
 import numpy as np
 
 from keras.src import tree
@@ -285,7 +287,21 @@ def unstack(x, num=None, axis=0):
     return [x[i] for i in range(x.shape[0])]
 
 
-def custom_gradient(fun):
-    raise NotImplementedError(
-        "`custom_gradient` is not supported with numpy backend"
-    )
+class custom_gradient:
+    """Decorator for custom gradients.
+
+    Args:
+        fun: Forward pass function.
+    """
+
+    def __init__(self, fun):
+        warnings.warn(
+            "`custom_gradient` for the numpy backend acts as a pass-through to "
+            "support the forward pass. No gradient computation or modification "
+            "takes place."
+        )
+        self.fun = fun
+
+    def __call__(self, *args, **kwargs):
+        outputs, _ = self.fun(*args, **kwargs)
+        return outputs
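For illustration, a minimal sketch of the new pass-through behavior. The `log1pexp` function is a hypothetical example, not part of the commit; it assumes the decorated function returns an `(outputs, grad_fn)` pair, matching the unpacking in `__call__` above:

import numpy as np

from keras.src.backend.numpy.core import custom_gradient


@custom_gradient  # emits the pass-through warning once, at decoration time
def log1pexp(x):
    e = np.exp(x)

    def grad(upstream):
        # Discarded on the numpy backend; only the forward output is used.
        return upstream * (1.0 - 1.0 / (1.0 + e))

    return np.log(1.0 + e), grad


y = log1pexp(np.array([0.0, 1.0]))  # returns just the forward value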

keras/src/backend/numpy/numpy.py (2 additions, 0 deletions)

@@ -65,6 +65,8 @@ def matmul(x1, x2):
         dtype = "int32"
     else:
         dtype = dtypes.result_type(x1.dtype, x2.dtype)
+    x1 = x1.astype(dtype)
+    x2 = x2.astype(dtype)
     return np.matmul(x1, x2).astype(dtype)
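Why cast the operands rather than only the product: `np.matmul` accumulates in the input dtype, so an overflow in a low-precision integer input survives a cast applied only to the result. A quick sketch (not from the commit) of the difference:

import numpy as np

a = np.full((1, 4), 100, dtype="int8")
b = np.full((4, 1), 100, dtype="int8")

# Accumulating in int8 wraps around before the final cast:
print(np.matmul(a, b).astype("int32"))  # wrapped value, not 40000

# Casting the operands first accumulates in int32:
print(np.matmul(a.astype("int32"), b.astype("int32")))  # [[40000]]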

keras/src/layers/core/dense_test.py (0 additions, 2 deletions)

With `custom_gradient` now acting as a pass-through on the numpy backend instead of raising `NotImplementedError`, the quantization tests below presumably no longer need a trainable backend, so the `requires_trainable_backend` markers are dropped here and in the einsum test file.

@@ -334,7 +334,6 @@ def test_enable_lora_when_already_enabled(self):
 
     # Test quantization-related (int8 and float8) methods
 
-    @pytest.mark.requires_trainable_backend
     def test_quantize_int8(self):
         layer = layers.Dense(units=16)
         layer.build((None, 8))
@@ -764,7 +763,6 @@ def test_quantize_float8_fitting(self):
             len(model.non_trainable_weights),
         )
 
-    @pytest.mark.requires_trainable_backend
     def test_quantize_float8_inference(self):
         config = dict(units=16)
         layer = layers.Dense(**config)

keras/src/layers/core/einsum_dense_test.py (0 additions, 3 deletions)

@@ -382,7 +382,6 @@ def test_lora_rank_argument(self):
 
     # Test quantization-related (int8 and float8) methods
 
-    @pytest.mark.requires_trainable_backend
     def test_quantize_int8(self):
         layer = layers.EinsumDense(
             equation="ab,bcd->acd",
@@ -471,7 +470,6 @@ def test_quantize_int8(self):
         ("btd,ndh->btnh", "btd,ndh->btnh", (None, 2, 8), (1, 2, 4)),
         ("btd,df->btf", "btd,df->btf", (None, 4), (1, 2, 4)),
     )
-    @pytest.mark.requires_trainable_backend
     def test_quantize_int8_with_specific_equations(
         self, equation, output_shape, input_shape
     ):
@@ -903,7 +901,6 @@ def test_quantize_float8_fitting(self):
             len(model.non_trainable_weights),
         )
 
-    @pytest.mark.requires_trainable_backend
     def test_quantize_float8_inference(self):
         config = dict(
             equation="ab,bcd->acd",
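As a quick sanity check that these tests can now exercise the numpy backend, a sketch along the lines of `test_quantize_int8` (assuming the standard `KERAS_BACKEND` environment variable and the public `Layer.quantize` API):

import os

os.environ["KERAS_BACKEND"] = "numpy"  # must be set before importing keras

import numpy as np

from keras import layers

layer = layers.Dense(units=16)
layer.build((None, 8))
layer.quantize("int8")  # previously blocked on numpy by custom_gradient

# The int8 forward pass now runs on this backend (no gradients are
# computed; custom_gradient is a pass-through here).
y = layer(np.random.random((2, 8)).astype("float32"))
print(y.shape)  # (2, 16)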
