You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Currently the inputs to `from_model_output` are not typed. However,
these functions cannot accept arbitrary inputs, they need to be a value
convertable to a `jax.Array`. This change fixes this so that:
- `from_model_output` takes in types of `Array` or `ArrayLike`
- Removes use of `jnp.array` as a type as it's equivalent to `Any`
- Makes members of Metric classes have type `Array`
- Moves mask checking code into its own function
While we could make everything use `Array` (instead of `ArrayLike`),
this would break code like:
```
@flax.struct.dataclass
class Collection(metrics.Collection):
train_accuracy: metrics.Accuracy
learning_rate: metrics.LastValue.from_output("learning_rate")
Collection.gather_from_model_output(learning_rate=0.02, ...)
```
which seems undesirable.
Note that `count` and `value` for `LastValue` have type `ArrayLike`,
as this code needs to support passing a plain number for `value` or
`count`. Also, the base `Metric.compute()` method has type `Any`,
because some metrics return `Array` while others use `dict[str, Array]`.
PiperOrigin-RevId: 529227218
0 commit comments