Find a way to appropriately replace the jacobian calculation with a TF-supported one using tf.GradientTape.jacobian #14

Closed
@DarrenZhang01

Description

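As titled. The JAX jacobian call in ntk_fn rejects the trax.tf_numpy ndarray it receives, reporting that it is not a valid JAX type: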

empirical.py", line 237, in ntk_fn
    ntk = jacobian(delta_vjp_jvp)(fx_dummy)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/api.py", line 633, in jacfun
    tree_map(partial(_check_input_dtype_jacrev, holomorphic), dyn_args)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/tree_util.py", line 164, in tree_map
    return treedef.unflatten(map(f, leaves))
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/api.py", line 500, in _check_input_dtype_revderiv
    _check_arg(x)
  File "/Library/Frameworks/Python.framework/Versions/3.8/lib/python3.8/site-packages/jax/api.py", line 1681, in _check_arg
    raise TypeError("Argument '{}' of type {} is not a valid JAX type"
TypeError: Argument 'ndarray<Tensor("ones:0", shape=(4, 1), dtype=float64)>' of type <class 'trax.tf_numpy.numpy_impl.arrays.ndarray'> is not a valid JAX type
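A minimal sketch of the replacement follows. It assumes that the function being differentiated (delta_vjp_jvp in the traceback) is written with TensorFlow ops and that its input can be converted to a tf.Tensor. How the trax.tf_numpy ndarray wrapper gets unwrapped is not covered here. tf_jacobian is a hypothetical helper name, not an existing trax or neural-tangents API. It wraps tf.GradientTape.jacobian so that, like jax.jacobian, it takes a function and returns a function, and the result has shape fun(x).shape + x.shape.

```python
import numpy as np
import tensorflow as tf


def tf_jacobian(fun):
    """Hypothetical stand-in for jax.jacobian(fun), built on tf.GradientTape.jacobian.

    The returned Jacobian has shape fun(x).shape + x.shape, matching jax.jacobian
    for a single array argument.
    """
    def jac_fn(x):
        x = tf.convert_to_tensor(x)
        with tf.GradientTape() as tape:
            # x is a plain tensor, not a tf.Variable, so it must be watched explicitly.
            tape.watch(x)
            y = fun(x)
        return tape.jacobian(y, x)
    return jac_fn


# Usage: Jacobian of a linear map x -> W @ x at a (4, 1) float64 input of ones,
# the same shape and dtype as the tensor in the traceback above.
W = tf.constant(np.arange(12, dtype=np.float64).reshape(3, 4))
jac = tf_jacobian(lambda x: tf.matmul(W, x))(tf.ones((4, 1), dtype=tf.float64))
print(jac.shape)  # (3, 1, 4, 1)
```

If the input is not watched, tape.jacobian returns None rather than a zero tensor. That is an easy mistake to make when porting code from jax.jacobian, which needs no such step.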
