
Commit 6b19334

Grvzard authored and james77777778 committed
Introduce DTypePolicyMap
* Introduce `DTypePolicyMap`
* Fix `LayerNormalization.get_config` (keras-team#19807)
* Propagate kwargs through `keras.ops.isclose` (keras-team#19782): this allows
  passing `atol` and `rtol`; switched from `**kwargs` to explicit `rtol`,
  `atol`, and `equal_nan` arguments, implemented for all backends; the
  TensorFlow version now uses code inspired by `tf.experimental.numpy.isclose`
  instead of calling it directly; docs added for the new parameters.
* Faster `in_top_k` implementation for the JAX backend (keras-team#19814),
  including a fix for a bug in the rank computation.
* Fix CI
* Fix `TypeError` in `Lambda.from_config` (keras-team#19827)
* Fix `dmtree.is_nested()` and parameterize the tree test (keras-team#19822)
* Fix `keras.ops.repeat` not returning the expected shape when `x` is a
  `KerasTensor` and `axis` is `None` (keras-team#19826); test that a dynamic
  dimension stays dynamic after repetition; improve error messages.
* `Metric.variables` is now recursive (keras-team#19830). This allows it to
  surface variables from metrics nested at any depth. Previously, metrics
  within metrics within metrics would not have their variables tracked in JAX,
  causing them to never be updated.
* Fix `get_file` when the HTTP response has no `Content-Length` header
  (keras-team#19833)
* Add `ops.switch` (keras-team#19834): update tests, fix an out-of-bound index
  issue, and revert the use of `torch.cond`.
* Use `absl.testing.parameterized` for `tree_test.py` (keras-team#19842), for
  consistency with all other tests. It is one less dependency, and each test
  name now says `optree` or `dmtree`.
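The newly propagated tolerance arguments follow NumPy's closeness rule, `|x1 - x2| <= atol + rtol * |x2|`. A minimal NumPy illustration of what the new `rtol`, `atol`, and `equal_nan` parameters control (plain NumPy here, independent of the Keras op, using the same defaults the op exposes):

```python
import numpy as np

a = np.array([1.0, 1.0, float("nan")])
b = np.array([1.0 + 1e-6, 1.5, float("nan")])

# Defaults (rtol=1e-5, atol=1e-8): the 1e-6 gap is within tolerance.
print(np.isclose(a, b))                       # [ True False False]

# A tighter rtol with atol=0 rejects the same gap.
print(np.isclose(a, b, rtol=1e-8, atol=0.0))  # [False False False]

# NaNs compare equal only when explicitly requested.
print(np.isclose(a, b, equal_nan=True))       # [ True False  True]
```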
* Make the batch norm mask shape error more descriptive (keras-team#19829):
  the error message now includes shape information to help with debugging.
* Fix code style
* doc: `ops.slice` (keras-team#19843)
* Correct the example code in `unit_normalization.py` (keras-team#19845):
  added a missing closing bracket and the exact output value observed after
  replicating the code; adjust the code example.
* Add a `training` argument to `Model.compute_loss()` (keras-team#19840).
  This allows models to perform different computations during training and
  evaluation. For instance, metrics that are expensive to compute can be
  skipped during training and only computed during evaluation. Note that
  backwards compatibility with overrides that do not have the `training`
  argument is maintained.
* Fix the compatibility issues of `Orthogonal` and `GRU` (keras-team#19844):
  add the legacy `Orthogonal` class name and the legacy `implementation`
  argument to `GRU`.
* Fix inconsistent behavior of `losses.sparse_categorical_crossentropy` with
  and without `ignore_class` (keras-team#19838), with tests and formatting
  fixes.
* Fix bugs with the `Mean`, `Accuracy`, and `BinaryAccuracy` metrics
  (keras-team#19847):
  - `reduce_to_samplewise_values` would not reduce `sample_weight` correctly
    because the number of dimensions of `values` was checked instead.
  - `reduce_to_samplewise_values` needs to explicitly broadcast
    `sample_weight`. Before, it was implicitly broadcast in the multiplication
    with `values`; however, the explicit broadcast is needed for the
    computation of `num_samples` to make the averaging correct. The implicit
    version caused a bug whenever `sample_weight` had rank 2 or more and a
    broadcast happened during the multiplication. This logic existed in
    `tf_keras`:
    https://github.com/keras-team/tf-keras/blob/master/tf_keras/metrics/base_metric.py#L508
  - `Accuracy` and `BinaryAccuracy` were doing a mean reduction too early,
    before multiplying by `sample_weight`. This matters when the rank of
    `sample_weight` is the same as that of `y_true` and `y_pred`.
* `DTypePolicyMap` follow-ups: add tests, fix a test, update the
  `default_policy` logic, improve serialization, and improve `__repr__` and
  `__eq__`.
* Add `custom_gradient` for the numpy backend (keras-team#19849)
* Fix the variable name when adding in the init function (keras-team#19853)
* Address review comments; update docstrings.
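The `sample_weight` broadcasting fix can be illustrated with a schematic NumPy version of a weighted-mean reduction (this is not the actual Keras helper, just a sketch of why the explicit broadcast matters):

```python
import numpy as np

def weighted_mean(values, sample_weight):
    """Weighted mean with an explicit broadcast of `sample_weight`.

    If the broadcast stays implicit (only inside `values * sample_weight`),
    the normalizer `np.sum(sample_weight)` is computed on the unbroadcast
    weights and is too small whenever broadcasting duplicated weight entries.
    """
    sample_weight = np.broadcast_to(sample_weight, np.shape(values))
    total = np.sum(values * sample_weight)
    num_samples = np.sum(sample_weight)  # correct only after the broadcast
    return total / num_samples

values = np.array([[1.0, 3.0], [5.0, 7.0]])  # rank 2
weights = np.array([[1.0], [2.0]])           # rank 2, broadcast along axis 1
```

Here the broadcast weights are `[[1, 1], [2, 2]]`, so the normalizer is 6, not the 3 that the unbroadcast weights would give.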
1 parent 7a4eb67 commit 6b19334


53 files changed: +1240 −205 lines

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@
 from keras.src.ops.core import slice
 from keras.src.ops.core import slice_update
 from keras.src.ops.core import stop_gradient
+from keras.src.ops.core import switch
 from keras.src.ops.core import unstack
 from keras.src.ops.core import vectorized_map
 from keras.src.ops.core import while_loop

keras/api/ops/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@
 from keras.src.ops.core import slice
 from keras.src.ops.core import slice_update
 from keras.src.ops.core import stop_gradient
+from keras.src.ops.core import switch
 from keras.src.ops.core import unstack
 from keras.src.ops.core import vectorized_map
 from keras.src.ops.core import while_loop

keras/src/backend/jax/core.py

Lines changed: 4 additions & 0 deletions
@@ -287,6 +287,10 @@ def slice_update(inputs, start_indices, updates):
     return jax.lax.dynamic_update_slice(inputs, updates, start_indices)
 
 
+def switch(index, branches, *operands):
+    return jax.lax.switch(index, branches, *operands)
+
+
 def while_loop(
     cond,
     body,

keras/src/backend/jax/math.py

Lines changed: 5 additions & 5 deletions
@@ -40,11 +40,11 @@ def top_k(x, k, sorted=True):
 
 
 def in_top_k(targets, predictions, k):
-    targets = targets[..., None]
-    topk_values = top_k(predictions, k)[0]
-    targets_values = jnp.take_along_axis(predictions, targets, axis=-1)
-    mask = targets_values >= topk_values
-    return jnp.any(mask, axis=1)
+    preds_at_label = jnp.take_along_axis(
+        predictions, jnp.expand_dims(targets, axis=-1), axis=-1
+    )
+    rank = 1 + jnp.sum(jnp.greater(predictions, preds_at_label), axis=-1)
+    return jnp.less_equal(rank, k)
 
 
 def logsumexp(x, axis=None, keepdims=False):
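The rewritten JAX `in_top_k` avoids materializing the top-k values entirely: a label is in the top k if and only if fewer than k predictions strictly exceed the prediction at the label. A NumPy sketch of the same rank computation (illustrative only, mirroring the JAX code above):

```python
import numpy as np

def in_top_k(targets, predictions, k):
    # Prediction score at each sample's target label, shape (batch, 1).
    preds_at_label = np.take_along_axis(
        predictions, np.expand_dims(targets, axis=-1), axis=-1
    )
    # 1-based rank = 1 + number of strictly greater scores in the row.
    rank = 1 + np.sum(predictions > preds_at_label, axis=-1)
    return rank <= k

targets = np.array([0, 2])
predictions = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0.7]])
```

Sample 0's label scores 0.1 and is beaten by two other scores (rank 3); sample 1's label scores 0.7 and is beaten by one (rank 2), so only the second is in the top 2.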

keras/src/backend/jax/numpy.py

Lines changed: 2 additions & 2 deletions
@@ -598,10 +598,10 @@ def imag(x):
     return jnp.imag(x)
 
 
-def isclose(x1, x2):
+def isclose(x1, x2, rtol=1e-5, atol=1e-8, equal_nan=False):
     x1 = convert_to_tensor(x1)
     x2 = convert_to_tensor(x2)
-    return jnp.isclose(x1, x2)
+    return jnp.isclose(x1, x2, rtol, atol, equal_nan)
 
 
 @sparse.densifying_unary

keras/src/backend/jax/trainer.py

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ def compute_loss_and_updates(
             y=y,
             y_pred=y_pred,
             sample_weight=sample_weight,
+            training=training,
         )
         if losses:
             self._losses_override.clear()
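The commit message states that overrides of `compute_loss()` lacking the `training` argument keep working. One way such backward compatibility can be implemented, sketched here under the assumption of signature inspection (the actual Keras shim may differ; `call_compute_loss` is a hypothetical helper):

```python
import inspect

def call_compute_loss(compute_loss, **kwargs):
    """Drop `training` before calling overrides that do not accept it."""
    params = inspect.signature(compute_loss).parameters
    accepts_training = "training" in params or any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    if not accepts_training:
        kwargs.pop("training", None)  # legacy override: silently discard
    return compute_loss(**kwargs)

def legacy_loss(y=None, y_pred=None):  # old-style override, no `training`
    return float(abs(y - y_pred))

def new_loss(y=None, y_pred=None, training=True):  # new-style override
    return 0.0 if training else float(abs(y - y_pred))
```

Both overrides can then be invoked uniformly with `training=...` in the keyword arguments.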

keras/src/backend/numpy/core.py

Lines changed: 26 additions & 4 deletions
@@ -1,3 +1,5 @@
+import warnings
+
 import numpy as np
 
 from keras.src import tree
@@ -241,6 +243,12 @@ def slice_update(inputs, start_indices, updates):
     return inputs
 
 
+def switch(index, branches, *operands):
+    index = convert_to_tensor(index, "int32")
+    index = np.clip(index, 0, len(branches) - 1)
+    return branches[index](*operands)
+
+
 def while_loop(
     cond,
     body,
@@ -279,7 +287,21 @@ def unstack(x, num=None, axis=0):
     return [x[i] for i in range(x.shape[0])]
 
 
-def custom_gradient(fun):
-    raise NotImplementedError(
-        "`custom_gradient` is not supported with numpy backend"
-    )
+class custom_gradient:
+    """Decorator for custom gradients.
+
+    Args:
+        fun: Forward pass function.
+    """
+
+    def __init__(self, fun):
+        warnings.warn(
+            "`custom_gradient` for the numpy backend acts as a pass-through to "
+            "support the forward pass. No gradient computation or modification "
+            "takes place."
+        )
+        self.fun = fun
+
+    def __call__(self, *args, **kwargs):
+        outputs, _ = self.fun(*args, **kwargs)
+        return outputs
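The numpy-backend `custom_gradient` now runs the forward pass and discards the gradient function instead of raising `NotImplementedError`. A self-contained replica of the decorator from the diff, with a toy decorated function:

```python
import warnings

class custom_gradient:
    """Pass-through decorator: keeps the forward pass, drops the grad fn."""

    def __init__(self, fun):
        warnings.warn(
            "`custom_gradient` for the numpy backend acts as a pass-through; "
            "no gradient computation or modification takes place."
        )
        self.fun = fun

    def __call__(self, *args, **kwargs):
        # `fun` returns (outputs, grad_fn); only the outputs survive.
        outputs, _ = self.fun(*args, **kwargs)
        return outputs

@custom_gradient
def square(x):
    # Forward value plus a gradient function that this backend ignores.
    return x * x, lambda upstream: 2 * x * upstream
```

Calling `square(3.0)` returns the forward value `9.0`; the custom gradient is simply never consulted on this backend.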

keras/src/backend/numpy/numpy.py

Lines changed: 4 additions & 2 deletions
@@ -65,6 +65,8 @@ def matmul(x1, x2):
         dtype = "int32"
     else:
         dtype = dtypes.result_type(x1.dtype, x2.dtype)
+    x1 = x1.astype(dtype)
+    x2 = x2.astype(dtype)
     return np.matmul(x1, x2).astype(dtype)
 
 
@@ -505,8 +507,8 @@ def imag(x):
     return np.imag(x)
 
 
-def isclose(x1, x2):
-    return np.isclose(x1, x2)
+def isclose(x1, x2, rtol=1e-5, atol=1e-8, equal_nan=False):
+    return np.isclose(x1, x2, rtol, atol, equal_nan)
 
 
 def isfinite(x):
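Casting the operands *before* `np.matmul` matters because casting only the result cannot undo overflow that already happened in a narrower dtype. A NumPy illustration of one failure mode this kind of pre-cast avoids (hypothetical values, not from the Keras tests):

```python
import numpy as np

x1 = np.array([[100]], dtype=np.int8)
x2 = np.array([[100]], dtype=np.int8)

# Cast only the result: the product is computed in int8 first,
# so 100 * 100 = 10000 has already wrapped around before the cast.
overflowed = np.matmul(x1, x2).astype("int32")

# Cast the operands first: the product is computed in int32.
correct = np.matmul(x1.astype("int32"), x2.astype("int32"))
```

The pre-cast version yields 10000; the post-cast version yields the wrapped int8 value instead.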

keras/src/backend/numpy/trainer.py

Lines changed: 2 additions & 2 deletions
@@ -28,8 +28,8 @@ def test_step(self, data):
             y_pred = self(x, training=False)
         else:
             y_pred = self(x)
-        loss = self.compute_loss(
-            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight
+        loss = self._compute_loss(
+            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=False
         )
         self._loss_tracker.update_state(
             loss, sample_weight=tree.flatten(x)[0].shape[0]

keras/src/backend/tensorflow/core.py

Lines changed: 13 additions & 0 deletions
@@ -355,6 +355,19 @@ def slice_update(inputs, start_indices, updates):
     return dynamic_update_slice(inputs, updates, start_indices)
 
 
+def switch(index, branches, *operands):
+    index = convert_to_tensor(index, "int32")
+    index = tf.clip_by_value(index, 0, len(branches) - 1)
+
+    # Workaround to deal with python closures. More details:
+    # https://github.com/tensorflow/tensorflow/issues/8776#issuecomment-311383887
+    def gen_fn(i):
+        return lambda: branches[i](*operands)
+
+    branch_fns = [gen_fn(i) for i in range(len(branches))]
+    return tf.switch_case(index, branch_fns)
+
+
 def while_loop(
     cond,
     body,
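Across backends, `switch` clamps the index into the valid range and then calls the selected branch with `*operands`. A backend-agnostic pure-Python sketch of that contract (mirroring the numpy backend's clamping, without tensor conversion):

```python
def switch(index, branches, *operands):
    # Clamp out-of-range indices to [0, len(branches) - 1], as the
    # numpy and tensorflow backends above do.
    index = max(0, min(int(index), len(branches) - 1))
    return branches[index](*operands)

branches = [
    lambda x, y: x + y,  # branch 0
    lambda x, y: x - y,  # branch 1
]
```

For example, `switch(0, branches, 5, 2)` dispatches to the addition branch, while any index past the end (e.g. 42) is clamped to the last branch rather than raising.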
