File tree Expand file tree Collapse file tree 3 files changed +20
-5
lines changed Expand file tree Collapse file tree 3 files changed +20
-5
lines changed Original file line number Diff line number Diff line change @@ -116,4 +116,16 @@ bool CheckTpuDevice(XlaDeviceType hw_type) {
116116 return false ;
117117}
118118
119+ bool CheckNeuronDevice (XlaDeviceType hw_type) {
120+ if (hw_type == XlaDeviceType::NEURON) {
121+ return true ;
122+ }
123+
124+ std::string pjrt_device = runtime::sys_util::GetEnvString (" PJRT_DEVICE" , " " );
125+ if (hw_type == XlaDeviceType::SPMD) {
126+ return pjrt_device == " NEURON" ;
127+ }
128+ return false ;
129+ }
130+
119131} // namespace torch_xla
Original file line number Diff line number Diff line change @@ -57,6 +57,9 @@ bool GetLockSpmdConfig();
5757// TODO(yeounoh) - see if we need to check for AOT compilation device type.
5858bool CheckTpuDevice (XlaDeviceType hw_type);
5959
60+ // Return true if the physical device type is NEURON.
61+ bool CheckNeuronDevice (XlaDeviceType hw_type);
62+
6063} // namespace torch_xla
6164
6265#endif // XLA_TORCH_XLA_CSRC_DEVICE_H_
Original file line number Diff line number Diff line change @@ -129,19 +129,19 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType(
129129 if (UseBF16 ()) {
130130 return xla::PrimitiveType::BF16;
131131 }
132- if (DowncastBF16 () || hw_type == XlaDeviceType::NEURON ) {
132+ if (DowncastBF16 () || CheckNeuronDevice ( hw_type) ) {
133133 return xla::PrimitiveType::F32;
134134 }
135135 return xla::PrimitiveType::F64;
136136 case xla::PrimitiveType::F32:
137137 return UseBF16 () || DowncastBF16 () ? xla::PrimitiveType::BF16
138138 : xla::PrimitiveType::F32;
139139 case xla::PrimitiveType::U16:
140- return hw_type != XlaDeviceType::NEURON ? xla::PrimitiveType::U16
141- : xla::PrimitiveType::U32 ;
140+ return CheckNeuronDevice ( hw_type) ? xla::PrimitiveType::U32
141+ : xla::PrimitiveType::U16 ;
142142 case xla::PrimitiveType::S16:
143- return hw_type != XlaDeviceType::NEURON ? xla::PrimitiveType::S16
144- : xla::PrimitiveType::S32 ;
143+ return CheckNeuronDevice ( hw_type) ? xla::PrimitiveType::S32
144+ : xla::PrimitiveType::S16 ;
145145 case xla::PrimitiveType::S64:
146146 return xla::PrimitiveType::S64;
147147 case xla::PrimitiveType::U64:
You can’t perform that action at this time.
0 commit comments