
Commit 7c7ad4e

[Neuron] Fix the XLADevice Neuron mappings for SPMD downcasts (#8335)
1 parent 1bac062 commit 7c7ad4e

File tree: 6 files changed, +23 -8 lines changed

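In short: under SPMD the call sites below see the virtual SPMD device type rather than the physical NEURON type, so Neuron-specific behavior (the dtype downcasts, the dense-gather and resize custom-call paths, and the upsample_nearest2d_backward fallback check) was being skipped. The change replaces direct hw_type == XlaDeviceType::NEURON comparisons with a new CheckNeuronDevice() helper that also treats the SPMD virtual device as Neuron when PJRT_DEVICE=NEURON.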

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 1 addition & 1 deletion
@@ -3667,7 +3667,7 @@ at::Tensor XLANativeFunctions::upsample_nearest2d_backward(
   // our XLA lowering.
   XlaDeviceType hw_type =
       static_cast<XlaDeviceType>(grad_output_tensor->GetDevice().type());
-  if (!CheckTpuDevice(hw_type) && hw_type != XlaDeviceType::NEURON) {
+  if (!CheckTpuDevice(hw_type) && !CheckNeuronDevice(hw_type)) {
     return at::native::call_fallback_fn<
         &xla_fallback, ATEN_OP(upsample_nearest2d_backward)>::call(grad_output,
                                                                    output_size,

torch_xla/csrc/data_ops.cpp

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ bool IsSparseGather(const xla::Shape& input_shape,
   // to avoid gather on a single float on TPU.
   XlaDeviceType hw_type =
       static_cast<XlaDeviceType>(bridge::GetCurrentDevice().type());
-  if (CheckTpuDevice(hw_type) || hw_type == XlaDeviceType::NEURON) {
+  if (CheckTpuDevice(hw_type) || CheckNeuronDevice(hw_type)) {
     // XLA_DENSE_GATHER_FACTOR can be used to finely control the
     // sparsity check.
     static int dense_gather_factor =

torch_xla/csrc/device.cpp

Lines changed: 12 additions & 0 deletions
@@ -116,4 +116,16 @@ bool CheckTpuDevice(XlaDeviceType hw_type) {
   return false;
 }
 
+bool CheckNeuronDevice(XlaDeviceType hw_type) {
+  if (hw_type == XlaDeviceType::NEURON) {
+    return true;
+  }
+
+  std::string pjrt_device = runtime::sys_util::GetEnvString("PJRT_DEVICE", "");
+  if (hw_type == XlaDeviceType::SPMD) {
+    return pjrt_device == "NEURON";
+  }
+  return false;
+}
+
 }  // namespace torch_xla
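The helper added above resolves the SPMD virtual device back to the underlying hardware through the PJRT_DEVICE environment variable. Below is a minimal standalone sketch of the same decision logic, using a mock enum and plain std::getenv in place of torch_xla's runtime helpers; the IsNeuron name and enum values are illustrative, not the library's API.

#include <cstdlib>
#include <iostream>
#include <string>

// Illustrative stand-in for torch_xla's XlaDeviceType enum.
enum class XlaDeviceType { CPU, TPU, NEURON, SPMD };

// Mirrors the decision logic of CheckNeuronDevice() above: a physical NEURON
// device matches directly; the SPMD virtual device matches only when the
// PJRT_DEVICE environment variable selects NEURON.
bool IsNeuron(XlaDeviceType hw_type) {
  if (hw_type == XlaDeviceType::NEURON) {
    return true;
  }
  if (hw_type == XlaDeviceType::SPMD) {
    const char* pjrt_device = std::getenv("PJRT_DEVICE");
    return pjrt_device != nullptr && std::string(pjrt_device) == "NEURON";
  }
  return false;
}

int main() {
  // With PJRT_DEVICE=NEURON set in the environment, the SPMD virtual device
  // is treated as Neuron hardware; other device types are unaffected.
  std::cout << std::boolalpha << IsNeuron(XlaDeviceType::SPMD) << "\n"
            << IsNeuron(XlaDeviceType::CPU) << "\n";
  return 0;
}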

torch_xla/csrc/device.h

Lines changed: 3 additions & 0 deletions
@@ -57,6 +57,9 @@ bool GetLockSpmdConfig();
 // TODO(yeounoh) - see if we need to check for AOT compilation device type.
 bool CheckTpuDevice(XlaDeviceType hw_type);
 
+// Return true if the physical device type is NEURON.
+bool CheckNeuronDevice(XlaDeviceType hw_type);
+
 }  // namespace torch_xla
 
 #endif  // XLA_TORCH_XLA_CSRC_DEVICE_H_

torch_xla/csrc/dtype.cpp

Lines changed: 5 additions & 5 deletions
@@ -129,19 +129,19 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType(
       if (UseBF16()) {
         return xla::PrimitiveType::BF16;
       }
-      if (DowncastBF16() || hw_type == XlaDeviceType::NEURON) {
+      if (DowncastBF16() || CheckNeuronDevice(hw_type)) {
         return xla::PrimitiveType::F32;
       }
       return xla::PrimitiveType::F64;
     case xla::PrimitiveType::F32:
       return UseBF16() || DowncastBF16() ? xla::PrimitiveType::BF16
                                          : xla::PrimitiveType::F32;
     case xla::PrimitiveType::U16:
-      return hw_type != XlaDeviceType::NEURON ? xla::PrimitiveType::U16
-                                               : xla::PrimitiveType::U32;
+      return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::U32
+                                        : xla::PrimitiveType::U16;
     case xla::PrimitiveType::S16:
-      return hw_type != XlaDeviceType::NEURON ? xla::PrimitiveType::S16
-                                               : xla::PrimitiveType::S32;
+      return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::S32
+                                        : xla::PrimitiveType::S16;
     case xla::PrimitiveType::S64:
       return xla::PrimitiveType::S64;
     case xla::PrimitiveType::U64:

torch_xla/csrc/resize_ops.cpp

Lines changed: 1 addition & 1 deletion
@@ -271,7 +271,7 @@ xla::XlaOp LowerForward2d(const std::string& target, xla::XlaOp input,
 
   XlaDeviceType hw_type =
       static_cast<XlaDeviceType>(bridge::GetCurrentDevice().type());
-  if (CheckTpuDevice(hw_type) || hw_type == XlaDeviceType::NEURON) {
+  if (CheckTpuDevice(hw_type) || CheckNeuronDevice(hw_type)) {
     // TPU uses custom call implementation
     resized =
         xla::CustomCall(input.builder(), target, {tinput}, resized_shape,
