Skip to content

Conversation

@Grvzard
Copy link
Contributor

@Grvzard Grvzard commented Jun 10, 2024

KerasTensor and the axis is None
Treat an input vector of size 1 as a scalar. Make the behavior consistent with other backends.
Closes #19821

Copy link
Collaborator

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

elif repeats_size != x_flatten_size:
raise ValueError(
"Size of `repeats` and "
"dimensions of `x` after flattening should be compatible"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Print the repeats_size and x_flatten_size as part of the error message, to make it more debuggable.

elif size_on_ax != repeats_size:
raise ValueError(
"Size of `repeats` and "
f"dimensions of `axis {self.axis} of x` should be compatible"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Print repeats and x.shape here

@codecov-commenter
Copy link

codecov-commenter commented Jun 11, 2024

Codecov Report

Attention: Patch coverage is 88.23529% with 2 lines in your changes missing coverage. Please review.

Project coverage is 78.84%. Comparing base (2305fad) to head (3b7be3d).
Report is 3 commits behind head on master.

Files Patch % Lines
keras/src/ops/numpy.py 88.23% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           master   #19826       +/-   ##
===========================================
+ Coverage   56.52%   78.84%   +22.31%     
===========================================
  Files         498      498               
  Lines       45801    45813       +12     
  Branches     8440     8443        +3     
===========================================
+ Hits        25890    36119    +10229     
+ Misses      18330     7993    -10337     
- Partials     1581     1701      +120     
Flag Coverage Δ
keras 78.69% <88.23%> (+22.17%) ⬆️
keras-jax 62.39% <88.23%> (?)
keras-numpy 56.63% <88.23%> (+0.10%) ⬆️
keras-tensorflow 63.68% <88.23%> (?)
keras-torch 62.36% <88.23%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Jun 11, 2024
@fchollet fchollet merged commit 2406037 into keras-team:master Jun 11, 2024
@google-ml-butler google-ml-butler bot removed ready to pull Ready to be merged into the codebase kokoro:force-run labels Jun 11, 2024
@Grvzard Grvzard deleted the fix-19821 branch June 11, 2024 15:56
james77777778 pushed a commit to james77777778/keras that referenced this pull request Jun 15, 2024
Fix `LayerNormalization.get_config` (keras-team#19807)

Propagate kwargs through `keras.ops.isclose` (keras-team#19782)

* propagate kwargs through isclose
this allows passing atol and rtol

* switch isclose **kwargs to explicit kwargs

* reduce line lengths

* fix ops.isclose signature

* fix ops.IsClose compute_output_spec signature

* implement isclose rtol atol equal_nan args for all backends

* shorten line lengths again

* revert using tf.experimental.numpy.isclose
tensorflow version now uses code inspired from tf.experimental.numpy.isclose

* fix lint

* add docs for new parameters

Faster in_top_k implementation for Jax backend (keras-team#19814)

* Faster in_top_k implementation.

* Fix bug in rank computation.

Fix CI

Fix TypeError in `Lambda.from_config` (keras-team#19827)

fixing dmtree.is_nested() and parameterized tree test (keras-team#19822)

Fix `keras.ops.repeat` cannot return an expected shape when `x` is a … (keras-team#19826)

* Fix `keras.ops.repeat` cannot return an expected shape when `x` is a `KerasTensor` and the `axis` is `None`

* Test dynamic is still dynamic after repetition

* Improve error messages

`Metric.variables` is now recursive. (keras-team#19830)

This allows it to surface variables from metrics nested at any depth.

Previously, metrics within metrics within metrics would not have their variables tracked in JAX, causing them to not be updated.

Fix `get_file` when the HTTP response has no `Content-Length` header (keras-team#19833)

Add `ops.switch` (keras-team#19834)

* Add `ops.switch`

* Update tests

* Fix out-of-bound issue

* Revert `torch.cond`

Use `absl.testing.parameterized` for `tree_test.py`. (keras-team#19842)

For consistency, use `absl.testing.parameterized` instead of `parameterized` for `tree_test.py` since that is used for all other tests.

It's one less dependency. It also says `optree` or `dmtree` in each test name.

Make batch norm mask shape error more descriptive (keras-team#19829)

* Made batch norm mask shape error more descriptive

* Added shape info in mask error message to help with degugging

Fix code style

doc: `ops.slice` (keras-team#19843)

corrected the example code in unit_normalization.py (keras-team#19845)

Added missing closing bracket and exact output value in example code after replicating the code.

Adjust code example

Add `training` argument to `Model.compute_loss()`. (keras-team#19840)

This allows models to perform different computations during training and evaluation. For instance, some expensive to compute metrics can be skipped during training and only computed during evaluation.

Note that backwards compatibility with overrides that do not have the `training` argument is maintained.

Fix the compatibility issues of `Orthogonal` and `GRU` (keras-team#19844)

* Add legacy `Orthogonal` class name

* Add legacy `implementation` arg to `GRU`

Fix inconsistent behavior of `losses.sparse_categorical_crossentropy`… (keras-team#19838)

* Fix inconsistent behavior of `losses.sparse_categorical_crossentropy` with and without `ignore_class`

* Test

* chore(format)

* Fix tests in `losses`

Fix bugs with `Mean`, `Accuracy` and `BinaryAccuracy` metrics. (keras-team#19847)

- `reduce_to_samplewise_values` would not reduce `sample_weights` correctly because the number of dimensions of `values` was checked.
- `reduce_to_samplewise_values` needs to explicitely broadcast `sample_weights`. Before, it was implicitly broadcast in the multiplication with `values`. However, the explicit broadcast is needed for the computation of `num_samples` for the averaging to be correct. This causes a bug when `sample_weights` is of rank 2 or more and a broadcast happens when doing the multiplication. This logic existed in `tf_keras`: https://github.com/keras-team/tf-keras/blob/master/tf_keras/metrics/base_metric.py#L508
- `Accuracy` and `BinaryAccuracy` were doing a mean reduction too early, before multiplying by `sample_weights`. This matters when the rank of `sample_weights` is the same as `y_true` and `y_pred`.

Add tests for `DTypePolicyMap`

Fix test

Update the logic of `default_policy`

Improve serialization of `DTypePolicyMap`

Improve `__repr__` and `__eq__`

Add `custom_gradient` for the numpy backend (keras-team#19849)

fix variable name when add in init function (keras-team#19853)

Address comments
james77777778 pushed a commit to james77777778/keras that referenced this pull request Jun 15, 2024
Introduce `DTypePolicyMap`

Fix `LayerNormalization.get_config` (keras-team#19807)

Propagate kwargs through `keras.ops.isclose` (keras-team#19782)

* propagate kwargs through isclose
this allows passing atol and rtol

* switch isclose **kwargs to explicit kwargs

* reduce line lengths

* fix ops.isclose signature

* fix ops.IsClose compute_output_spec signature

* implement isclose rtol atol equal_nan args for all backends

* shorten line lengths again

* revert using tf.experimental.numpy.isclose
tensorflow version now uses code inspired from tf.experimental.numpy.isclose

* fix lint

* add docs for new parameters

Faster in_top_k implementation for Jax backend (keras-team#19814)

* Faster in_top_k implementation.

* Fix bug in rank computation.

Fix CI

Fix TypeError in `Lambda.from_config` (keras-team#19827)

fixing dmtree.is_nested() and parameterized tree test (keras-team#19822)

Fix `keras.ops.repeat` cannot return an expected shape when `x` is a … (keras-team#19826)

* Fix `keras.ops.repeat` cannot return an expected shape when `x` is a `KerasTensor` and the `axis` is `None`

* Test dynamic is still dynamic after repetition

* Improve error messages

`Metric.variables` is now recursive. (keras-team#19830)

This allows it to surface variables from metrics nested at any depth.

Previously, metrics within metrics within metrics would not have their variables tracked in JAX, causing them to not be updated.

Fix `get_file` when the HTTP response has no `Content-Length` header (keras-team#19833)

Add `ops.switch` (keras-team#19834)

* Add `ops.switch`

* Update tests

* Fix out-of-bound issue

* Revert `torch.cond`

Use `absl.testing.parameterized` for `tree_test.py`. (keras-team#19842)

For consistency, use `absl.testing.parameterized` instead of `parameterized` for `tree_test.py` since that is used for all other tests.

It's one less dependency. It also says `optree` or `dmtree` in each test name.

Make batch norm mask shape error more descriptive (keras-team#19829)

* Made batch norm mask shape error more descriptive

* Added shape info in mask error message to help with degugging

Fix code style

doc: `ops.slice` (keras-team#19843)

corrected the example code in unit_normalization.py (keras-team#19845)

Added missing closing bracket and exact output value in example code after replicating the code.

Adjust code example

Add `training` argument to `Model.compute_loss()`. (keras-team#19840)

This allows models to perform different computations during training and evaluation. For instance, some expensive to compute metrics can be skipped during training and only computed during evaluation.

Note that backwards compatibility with overrides that do not have the `training` argument is maintained.

Fix the compatibility issues of `Orthogonal` and `GRU` (keras-team#19844)

* Add legacy `Orthogonal` class name

* Add legacy `implementation` arg to `GRU`

Fix inconsistent behavior of `losses.sparse_categorical_crossentropy`… (keras-team#19838)

* Fix inconsistent behavior of `losses.sparse_categorical_crossentropy` with and without `ignore_class`

* Test

* chore(format)

* Fix tests in `losses`

Fix bugs with `Mean`, `Accuracy` and `BinaryAccuracy` metrics. (keras-team#19847)

- `reduce_to_samplewise_values` would not reduce `sample_weights` correctly because the number of dimensions of `values` was checked.
- `reduce_to_samplewise_values` needs to explicitely broadcast `sample_weights`. Before, it was implicitly broadcast in the multiplication with `values`. However, the explicit broadcast is needed for the computation of `num_samples` for the averaging to be correct. This causes a bug when `sample_weights` is of rank 2 or more and a broadcast happens when doing the multiplication. This logic existed in `tf_keras`: https://github.com/keras-team/tf-keras/blob/master/tf_keras/metrics/base_metric.py#L508
- `Accuracy` and `BinaryAccuracy` were doing a mean reduction too early, before multiplying by `sample_weights`. This matters when the rank of `sample_weights` is the same as `y_true` and `y_pred`.

Add tests for `DTypePolicyMap`

Fix test

Update the logic of `default_policy`

Improve serialization of `DTypePolicyMap`

Improve `__repr__` and `__eq__`

Add `custom_gradient` for the numpy backend (keras-team#19849)

fix variable name when add in init function (keras-team#19853)

Address comments

Update docstrings
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

Status: Assigned Reviewer

Development

Successfully merging this pull request may close these issues.

keras.ops.repeat cannot return an exptected shape when x is a KerasTensor and the axis is None

4 participants