## Description
Let's collect a bunch of code samples from array-api libraries that `array-api-typing` could improve. That should help give us a better idea of what is needed. They should also be useful for (integration) testing.
## Problems
### `scipy.cluster.vq`
Here's an example from the scipy docs:
```python
import torch
from scipy.cluster.vq import vq

code_book = torch.tensor([[1., 1., 1.],
                          [2., 2., 2.]])
features = torch.tensor([[1.9, 2.3, 1.7],
                         [1.5, 2.5, 2.2],
                         [0.8, 0.6, 1.7]])

code, dist = vq(features, code_book)
print(code)  # tensor([1, 1, 0], dtype=torch.int32)
print(dist)  # tensor([0.4359, 0.7348, 0.8307])
```
There is currently no good way to annotate this behavior in scipy-stubs. With `scipy-stubs==1.15.3.0`, `code` and `dist` are inferred as 1-d numpy arrays of int and float dtypes, respectively. The reason for this is the `torch.Tensor.__array__` method, which causes the tensor to be interpreted as a numpy "array-like".
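To illustrate the mechanism (a minimal sketch, not scipy-stubs or torch itself — `FakeTensor` is a hypothetical stand-in for `torch.Tensor`): any object that defines `__array__` is accepted by NumPy as an "array-like" and silently converted to an `ndarray`, which is why stubs that type such inputs as array-likes lose the original array type.

```python
import numpy as np

class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: it defines __array__,
    so NumPy treats it as an "array-like"."""

    def __init__(self, data):
        self._data = data

    def __array__(self, dtype=None, copy=None):
        # NumPy may pass dtype/copy (the copy kwarg is ignored in this sketch).
        return np.asarray(self._data, dtype=dtype)

t = FakeTensor([1.0, 2.0, 3.0])
out = np.asarray(t)  # goes through FakeTensor.__array__
print(type(out))     # <class 'numpy.ndarray'>
```

This is the same conversion path that makes `vq`'s torch inputs come back as plain numpy arrays in the inferred types.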
### `efax`
By @NeilGirdhar from #22 (comment)
Consider practically any function in my efax library. For example, we have:
```python
def _r_s_mu(self) -> tuple[JaxComplexArray, JaxRealArray, JaxComplexArray]:
    xp = array_namespace(self)
    r = -self.pseudo_precision / self.negative_precision
    s = xp.reciprocal((abs_square(r) - 1.0) * self.negative_precision)
    k = self.pseudo_precision / self.negative_precision
    l_eta = 0.5 * self.mean_times_precision / ((abs_square(k) - 1.0) * self.negative_precision)
    mu = xp.conj(l_eta) - xp.conj(self.pseudo_precision / self.negative_precision) * l_eta
    return r, s, mu
```
This is barely type-checked. It would be nice if `xp` could be annotated as `ArrayNamespace`, and `r`, `s`, `k`, `l_eta`, and `mu` as `Array`, etc.
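As a sketch of what such annotations might look like, here is a minimal, hypothetical `Array`/`ArrayNamespace` protocol pair. This is not the actual `array-api-typing` API; all names and members here are illustrative, covering only what the snippet above uses:

```python
from typing import Protocol, runtime_checkable

@runtime_checkable
class Array(Protocol):
    """Hypothetical array protocol (only a couple of members shown)."""
    @property
    def shape(self) -> tuple[int, ...]: ...
    def __add__(self, other: "Array", /) -> "Array": ...

@runtime_checkable
class ArrayNamespace(Protocol):
    """Hypothetical namespace protocol covering the functions used above."""
    def reciprocal(self, x: Array, /) -> Array: ...
    def conj(self, x: Array, /) -> Array: ...

# Any duck-typed object with the right members satisfies the protocol:
class ToyArray:
    @property
    def shape(self) -> tuple[int, ...]:
        return (3,)
    def __add__(self, other: "ToyArray", /) -> "ToyArray":
        return self

assert isinstance(ToyArray(), Array)
```

With protocols like these, `array_namespace` could be typed to return an `ArrayNamespace`, and the intermediate values (`r`, `s`, `k`, `l_eta`, `mu`) would be inferred as `Array` rather than falling back to `Any`.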