https://towardsdatascience.com/deep-learning-with-jax-and-elegy-c0765e3ec31a
This is a simple example of how to build an image classifier with JAX and Elegy.
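I am using transfer learning through bottleneck features (the outputs of a pretrained network's convolutional base, computed once and then used as inputs to a new classifier). This lets me use any architecture as the feature extractor while building the upper layers with JAX.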
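Here is a minimal sketch of the JAX side. It assumes the bottleneck features have already been computed as a `(num_images, feature_dim)` array, so the upper layers can be a small MLP head trained with gradient descent. It uses plain JAX rather than Elegy's training API. The names `init_params`, `head` and `train_step` and the sizes (512 features, 128 hidden units, 10 classes) are hypothetical, and random arrays stand in for real features and labels.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes: 512-dim bottleneck features, a 128-unit hidden layer, 10 classes.
FEATURE_DIM, HIDDEN_DIM, NUM_CLASSES = 512, 128, 10
LEARNING_RATE = 1e-2


def init_params(key):
    """He-style normal initialisation for a two-layer MLP head."""
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (FEATURE_DIM, HIDDEN_DIM)) * jnp.sqrt(2.0 / FEATURE_DIM),
        "b1": jnp.zeros(HIDDEN_DIM),
        "w2": jax.random.normal(k2, (HIDDEN_DIM, NUM_CLASSES)) * jnp.sqrt(2.0 / HIDDEN_DIM),
        "b2": jnp.zeros(NUM_CLASSES),
    }


def head(params, features):
    """Map bottleneck features (batch, FEATURE_DIM) to logits (batch, NUM_CLASSES)."""
    hidden = jax.nn.relu(features @ params["w1"] + params["b1"])
    return hidden @ params["w2"] + params["b2"]


def loss_fn(params, features, labels):
    """Mean softmax cross-entropy for integer class labels."""
    log_probs = jax.nn.log_softmax(head(params, features))
    return -jnp.mean(jnp.take_along_axis(log_probs, labels[:, None], axis=1))


@jax.jit
def train_step(params, features, labels):
    """One plain gradient-descent step; only the head's parameters are updated."""
    loss, grads = jax.value_and_grad(loss_fn)(params, features, labels)
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


# Usage with random stand-ins for real bottleneck features and labels.
key_params, key_x, key_y = jax.random.split(jax.random.PRNGKey(0), 3)
params = init_params(key_params)
features = jax.random.normal(key_x, (64, FEATURE_DIM))
labels = jax.random.randint(key_y, (64,), 0, NUM_CLASSES)

initial_loss = float(loss_fn(params, features, labels))
for _ in range(100):
    params, loss = train_step(params, features, labels)
final_loss = float(loss_fn(params, features, labels))
```

A real run would iterate over mini-batches and hold out a validation split. The full-batch loop here only keeps the sketch short.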
Some of the better-known architectures are not yet implemented in JAX, so this example shows how to use a Keras/TensorFlow model as the feature extractor and feed its output to a JAX model.
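The Keras side might look like the following sketch. It assumes MobileNetV2 from `tf.keras.applications` as the extractor, though any application model built with `include_top=False` should work the same way. Setting `weights=None` only keeps the sketch offline. For real transfer learning you would load `weights="imagenet"`. The helper `extract_bottleneck_features` is hypothetical. It runs the frozen Keras model once and converts the resulting NumPy array into a JAX array that a JAX model can consume.

```python
import numpy as np
import tensorflow as tf
import jax.numpy as jnp

# Assumed extractor: MobileNetV2 without its classification top. weights=None keeps
# the sketch offline; for real transfer learning, load weights="imagenet" instead.
extractor = tf.keras.applications.MobileNetV2(
    input_shape=(96, 96, 3),
    include_top=False,   # drop the ImageNet classifier layers
    weights=None,
    pooling="avg",       # global average pooling: one 1280-dim vector per image
)
extractor.trainable = False  # the Keras model is only used for inference


def extract_bottleneck_features(images, batch_size=32):
    """Hypothetical helper: (N, 96, 96, 3) images in [0, 255] -> (N, 1280) JAX array."""
    # preprocess_input rescales pixel values to [-1, 1], as MobileNetV2 expects.
    x = tf.keras.applications.mobilenet_v2.preprocess_input(images.astype("float32"))
    features = extractor.predict(x, batch_size=batch_size, verbose=0)  # NumPy array
    return jnp.asarray(features)
```

Because the extractor is frozen, the features only need to be computed once. Caching them with `np.save` and reloading them avoids running the Keras model on every training epoch.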