
Disable 64 bit codegen to HLO for Neuron backend #24682

@patrick-toulme

Description


Hi, I would like to open a PR that disables codegen of any FP64, S64, or U64 types when jax_backend=="Neuron". Neuron hardware does not support 64-bit types, so for the Neuron backend only, I want any 64-bit types in JAX to be lowered to their 32-bit counterparts.
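The proposed mapping can be sketched in a few lines of Python. This is hypothetical: `narrow_element_type`, the `NARROW_64_TO_32` table, and the lowercase HLO type strings are illustrative assumptions, not JAX or XLA APIs. The real change would have to live wherever JAX assigns element types during lowering. For context, JAX's default configuration (the `jax_enable_x64` flag set to False) already canonicalizes 64-bit dtypes to 32-bit ones on every backend; the sketch only shows doing the narrowing conditionally for a single backend.

```python
# Hypothetical sketch of the proposed Neuron-only narrowing of 64-bit
# HLO element types. Names here are illustrative, not JAX/XLA APIs.

# 64-bit HLO element types and the 32-bit types they would be lowered to.
NARROW_64_TO_32 = {
    "f64": "f32",
    "s64": "s32",
    "u64": "u32",
}


def narrow_element_type(hlo_type: str, backend: str) -> str:
    """Return the (lowercased) element type to emit for `backend`.

    Only the "neuron" backend narrows 64-bit types; all other backends,
    and all non-64-bit types, pass through unchanged.
    """
    normalized = hlo_type.lower()
    if backend.lower() != "neuron":
        return normalized
    return NARROW_64_TO_32.get(normalized, normalized)


if __name__ == "__main__":
    print(narrow_element_type("F64", "Neuron"))  # f32
    print(narrow_element_type("F64", "cpu"))     # f64
```

Keying the check on the backend name keeps other backends' behavior identical, which matches the "for Neuron backend only" requirement above.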

We made an equivalent PR to torch-xla: pytorch/xla@7c7ad4e

My question: where in the JAX library is the HLO element type chosen during codegen? I searched but could not find it.

Metadata


Assignees: No one assigned
Labels: enhancement (New feature or request)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
