Description
Hi!
I know this is something that has been asked before, but I just wanted to ask again: is there any plan to eventually support exporting Flax models in the ONNX format?
ONNX has become a very popular way of distributing models, a kind of lingua franca supported by various inference engines like TensorRT, NCNN, OpenVINO, ONNX Runtime, et cetera.
The current workaround is to convert your Flax model to TensorFlow first using `jax2tf`, and then convert to ONNX from that using `tf2onnx`. While this works, the resulting ONNX models often contain various unnecessary steps, and even simple things aren't mapped to the expected corresponding operations. I've also only been able to get this to work with `enable_xla=False` in the `jax2tf` conversion, which has been deprecated.
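For concreteness, here's roughly what that workaround looks like end to end. This is just a minimal sketch with a toy MLP and made-up shapes/opset, not code from any particular project:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import tensorflow as tf
import tf2onnx
from jax.experimental import jax2tf

# A toy Flax model standing in for whatever you actually want to export.
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(16)(x)
        x = nn.relu(x)
        return nn.Dense(4)(x)

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))

def predict(x):
    return model.apply(params, x)

# enable_xla=False keeps the lowered graph in plain TF ops so tf2onnx
# can map them; as noted above, this flag is deprecated.
tf_fn = tf.function(
    jax2tf.convert(predict, enable_xla=False),
    autograph=False,
    input_signature=[tf.TensorSpec([1, 8], tf.float32)],
)

# Convert the concrete TF function to ONNX and write it to disk.
onnx_model, _ = tf2onnx.convert.from_function(
    tf_fn,
    input_signature=[tf.TensorSpec([1, 8], tf.float32)],
    opset=13,
    output_path="model.onnx",
)
```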
I understand there's an argument that it would make more sense to have this at the JAX level instead, but honestly, I don't think ONNX is very popular outside of ML, and doing it at the Flax level might make it easier to map operations 1:1.
FWIW, Equinox supports this by bridging through TF first as well, so that seems to be the status quo everywhere.
Thanks in advance!