Hi, I'd like to open a PR that disables codegen to F64, S64, and U64 when jax_backend=="Neuron". Neuron hardware does not support 64-bit types, so for the Neuron backend only, I want JAX to lower every 64-bit type to its 32-bit counterpart.
We made the same change in torch-xla: pytorch/xla@7c7ad4e.
My question: where in the JAX codebase is the HLO element type chosen during codegen? I searched but couldn't find it.
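For context, here is a minimal sketch of how to check which element types JAX emits for a jitted function. It assumes a recent JAX release where `jax.jit(f).lower(...).as_text()` returns the lowered module text. With the default `jax_enable_x64=False`, JAX already canonicalizes 64-bit dtypes to 32-bit before lowering. The function `f` and the input are only illustrative.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Pin the default explicitly: with x64 disabled, JAX canonicalizes
# 64-bit dtypes (float64, int64, uint64) to their 32-bit counterparts.
jax.config.update("jax_enable_x64", False)

def f(x):
    # Illustrative elementwise op; the Python float is weakly typed and follows x's dtype.
    return x * 2.0

x64 = np.arange(4, dtype=np.float64)

# Canonicalization applied to 64-bit dtypes.
print(jax.dtypes.canonicalize_dtype(np.float64))  # float32
print(jax.dtypes.canonicalize_dtype(np.int64))    # int32
print(jax.dtypes.canonicalize_dtype(np.uint64))   # uint32

# Lower (without compiling) and inspect the element types in the emitted module.
hlo_text = jax.jit(f).lower(x64).as_text()
print(hlo_text)
```

Note that `jax_enable_x64` is a global switch, not a per-backend one, so it doesn't by itself give Neuron-only behavior.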