
Support for exporting models in the ONNX format #4430

Open
@Artoriuz

Description

Hi!

I know this is something that has been asked before, but I just wanted to ask again: is there any plan to eventually support exporting Flax models in the ONNX format?

ONNX has become a very popular way of distributing models, serving as a kind of lingua franca supported by various inference engines like TensorRT, NCNN, OpenVINO, ONNX Runtime, et cetera.

The current workaround is to convert your Flax model to TensorFlow first using jax2tf, and then convert that to ONNX using tf2onnx. While this works, the resulting ONNX models often contain various unnecessary steps, and even simple things aren't mapped to the expected corresponding operations. I've also only been able to get this to work with enable_xla=False in the jax2tf conversion, which has been deprecated.
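
For reference, this is roughly what the workaround looks like today (a minimal sketch with a toy MLP standing in for a real model; the exact arguments may vary between jax2tf and tf2onnx versions):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import tensorflow as tf
import tf2onnx
from jax.experimental import jax2tf

# Toy Flax module for illustration; substitute your own model.
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(64)(x))
        return nn.Dense(10)(x)

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 32)))

# Close over the params so the exported function takes only the input tensor.
def predict(x):
    return model.apply(params, x)

# Step 1: JAX/Flax -> TensorFlow. enable_xla=False emits plain TF ops that
# tf2onnx can map, but the flag is deprecated in recent JAX releases.
input_spec = [tf.TensorSpec((1, 32), tf.float32)]
tf_fn = tf.function(
    jax2tf.convert(predict, enable_xla=False),
    autograph=False,
    input_signature=input_spec,
)

# Step 2: TensorFlow -> ONNX.
onnx_model, _ = tf2onnx.convert.from_function(
    tf_fn, input_signature=input_spec, output_path="model.onnx"
)
```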

I understand there's an argument to be made that this would make more sense at the JAX level instead, but honestly I don't think ONNX is very popular outside of ML, and doing it at the Flax level might make it easier to map operations 1:1.

FWIW, Equinox supports this by bridging through TF first as well, so that seems to be the status quo everywhere.

Thanks in advance!
