Skip to content

Commit 418c0c0

Browse files
committed
[Neuron] Fix the XLADevice Neuron mappings for SPMD downcasts
1 parent 5dbdb8d commit 418c0c0

File tree

3 files changed

+20
-5
lines changed

3 files changed

+20
-5
lines changed

torch_xla/csrc/device.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,16 @@ bool CheckTpuDevice(XlaDeviceType hw_type) {
116116
return false;
117117
}
118118

119+
bool CheckNeuronDevice(XlaDeviceType hw_type) {
120+
if (hw_type == XlaDeviceType::NEURON) {
121+
return true;
122+
}
123+
124+
std::string pjrt_device = runtime::sys_util::GetEnvString("PJRT_DEVICE", "");
125+
if (hw_type == XlaDeviceType::SPMD) {
126+
return pjrt_device == "NEURON";
127+
}
128+
return false;
129+
}
130+
119131
} // namespace torch_xla

torch_xla/csrc/device.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ bool GetLockSpmdConfig();
5757
// TODO(yeounoh) - see if we need to check for AOT compilation device type.
5858
bool CheckTpuDevice(XlaDeviceType hw_type);
5959

60+
// Return true if the physical device type is NEURON.
61+
bool CheckNeuronDevice(XlaDeviceType hw_type);
62+
6063
} // namespace torch_xla
6164

6265
#endif // XLA_TORCH_XLA_CSRC_DEVICE_H_

torch_xla/csrc/dtype.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,19 +129,19 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType(
129129
if (UseBF16()) {
130130
return xla::PrimitiveType::BF16;
131131
}
132-
if (DowncastBF16() || hw_type == XlaDeviceType::NEURON) {
132+
if (DowncastBF16() || CheckNeuronDevice(hw_type)) {
133133
return xla::PrimitiveType::F32;
134134
}
135135
return xla::PrimitiveType::F64;
136136
case xla::PrimitiveType::F32:
137137
return UseBF16() || DowncastBF16() ? xla::PrimitiveType::BF16
138138
: xla::PrimitiveType::F32;
139139
case xla::PrimitiveType::U16:
140-
return hw_type != XlaDeviceType::NEURON ? xla::PrimitiveType::U16
141-
: xla::PrimitiveType::U32;
140+
return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::U32
141+
: xla::PrimitiveType::U16;
142142
case xla::PrimitiveType::S16:
143-
return hw_type != XlaDeviceType::NEURON ? xla::PrimitiveType::S16
144-
: xla::PrimitiveType::S32;
143+
return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::S32
144+
: xla::PrimitiveType::S16;
145145
case xla::PrimitiveType::S64:
146146
return xla::PrimitiveType::S64;
147147
case xla::PrimitiveType::U64:

0 commit comments

Comments
 (0)