Closed
Description
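As titled: computing the empirical NTK (`ntk_fn` in `empirical.py`) fails when the input is a `trax.tf_numpy` ndarray. `jax.jacobian` rejects the argument with "is not a valid JAX type". Traceback (truncated):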
empirical.py", line 237, in ntk_fn
ntk = jacobian(delta_vjp_jvp)(fx_dummy)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/api.py", line 633, in jacfun
tree_map(partial(_check_input_dtype_jacrev, holomorphic), dyn_args)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/tree_util.py", line 164, in tree_map
return treedef.unflatten(map(f, leaves))
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/api.py", line 500, in _check_input_dtype_revderiv
_check_arg(x)
File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/api.py", line 1681, in _check_arg
raise TypeError("Argument '{}' of type {} is not a valid JAX type"
TypeError: Argument 'ndarray<Tensor("ones:0", shape=(4, 1), dtype=float64)>' of type <class 'trax.tf_numpy.numpy_impl.arrays.ndarray'> is not a valid JAX type