Following the #548 discussion, and while we wait for discrete latent variables, it would be nice to have a Gumbel-Softmax approximation to the categorical distribution, as featured in Pyro. (I hadn't realized Pyro uses a different name for Gumbel-Softmax.) Hopefully replicating it would be straightforward?
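For reference, the sampling side of the relaxation is only a few lines of JAX. The sketch below is an illustration, not Pyro's or NumPyro's implementation. `gumbel_softmax_sample` is a hypothetical helper name. It adds standard Gumbel noise to the logits and applies a temperature-scaled softmax, which gives a point on the simplex that approaches a one-hot vector as the temperature goes to 0:

```python
import jax
import jax.numpy as jnp


def gumbel_softmax_sample(key, logits, temperature):
    """Relaxed one-hot sample from Categorical(logits) via Gumbel-Softmax.

    Hypothetical helper for illustration only; not a NumPyro API.
    Small temperatures give near one-hot samples, large ones smoother samples.
    """
    # Standard Gumbel(0, 1) noise, one draw per category (and per batch entry)
    gumbels = jax.random.gumbel(key, shape=logits.shape)
    # Temperature-scaled softmax of the perturbed logits; result lies on the simplex
    return jax.nn.softmax((logits + gumbels) / temperature, axis=-1)


# Example: a batch of 4 relaxed samples over 3 categories
key = jax.random.PRNGKey(0)
logits = jnp.log(jnp.array([0.2, 0.3, 0.5]))
samples = gumbel_softmax_sample(key, jnp.broadcast_to(logits, (4, 3)), 0.5)
```

Because softmax is monotone, the argmax of each relaxed sample is an exact categorical draw (the Gumbel-max trick), while the sample itself stays differentiable in `logits`. A full NumPyro distribution would also need a `log_prob` (the Concrete density), which this sketch omits.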
NumPyro (thanks to JAX) seems uniquely suited to problems involving large discrete structures (e.g. networks), so the ability to recover latent discrete variables (or continuous approximations of them) would be fantastic!