Replies: 1 comment
-
We have deprecated converting JAX to TF ops, and we plan to remove that. In particular, the option you mention (enable_xla=False) was never really finished, and never got to the point where we could guarantee fidelity. That conversion was useful while the TF-independent toolchain matured, but now we can do almost everything we did with the TF graph without the conversion. StableHLO is now the only recommended and supported export from JAX. This means that, sadly, you cannot continue to rely on the conversion to ONNX by way of TF. The real solution here would be to write a StableHLO-to-ONNX converter, but this could be a very large task.
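To make the recommended path concrete, here is a minimal sketch of exporting a jitted JAX function to StableHLO with the `jax.export` API (available in recent JAX releases). The function `f` below is just a placeholder standing in for a real model's apply function:

```python
import jax
import jax.numpy as jnp
from jax import export

# Placeholder for a real model's apply function.
def f(x):
    return jnp.sin(x) * 2.0

# Export against abstract input shapes; no concrete data is needed.
exported = export.export(jax.jit(f))(
    jax.ShapeDtypeStruct((4,), jnp.float32)
)

# The exported artifact is an MLIR module in the StableHLO dialect.
print(exported.mlir_module())
```

From here, a StableHLO-to-ONNX converter (which, as noted above, does not exist today) would be the missing piece to recover an ONNX export path.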
-
Hello,
The standard approach for exporting Flax (and other JAX libraries') models to ONNX has been to first create a TensorFlow function using jax2tf.convert, and then export that function to ONNX. With enable_xla now deprecated (and no longer supported), does this mean that exporting models trained in JAX to ONNX is no longer possible?
This would be a significant issue for me, as I work on a team where the preferred format for sharing trained models is ONNX. Are there any plans to provide a workaround or an alternative method to do so?
Thank you
Note: this is related to a similar discussion in the Flax repository, but I could not find a comparable discussion here.
google/flax#4430