jax-vs-pytorch

Comparisons between PyTorch and JAX (Flax)

Todo

- Training MLP on MNIST
- Training CNN on CIFAR-10
- Training VAE
- Training GAN