Several of our APIs (jax.linearize, AOT lowering, jax.make_jaxpr, ...) emphasize duck-typed arguments that correspond to abstract values/arrays. We stress in documentation that any object with shape and dtype (etc.) is acceptable. But then we annotate these APIs with the concrete jax.ShapeDtypeStruct, contradicting the docs.
One way forward may be to define a Protocol type, use it as the annotation, and continue to expose jax.ShapeDtypeStruct for convenience.
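A minimal sketch of what such a protocol could look like, using only the standard library. The names SupportsShapeDtype, MyAbstractArray, and describe are hypothetical and chosen for illustration; none of them exist in JAX, and the real protocol would likely need to cover the optional fields as well.

```python
from dataclasses import dataclass
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class SupportsShapeDtype(Protocol):
    """Hypothetical protocol: any object exposing `shape` and `dtype`."""

    @property
    def shape(self) -> tuple[int, ...]: ...

    @property
    def dtype(self) -> Any: ...


@dataclass(frozen=True)
class MyAbstractArray:
    """A user-defined stand-in that never subclasses any JAX type."""

    shape: tuple[int, ...]
    dtype: str


def describe(x: SupportsShapeDtype) -> str:
    # Annotated with the protocol, so static checkers accept any
    # structurally compatible object, not just one concrete class.
    return f"{x.dtype}{list(x.shape)}"


arr = MyAbstractArray(shape=(2, 3), dtype="float32")
print(isinstance(arr, SupportsShapeDtype))  # structural check, attribute presence only
print(describe(arr))
```

With an annotation like this, the signature would match what the docs promise: jax.ShapeDtypeStruct satisfies the protocol structurally, and so does any user-defined object with the same attributes. Note that the runtime isinstance check only tests that the attributes are present, not their types.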
(It's a bit unclear to me what to name this protocol. ShapeDtypeStruct itself is already arguably a misnomer, since there are other optional fields there too.)