Skip to content

Commit

Permalink
[MPS] Casting int64 to int32 for reduction ops and raise warning. (py…
Browse files Browse the repository at this point in the history
…torch#94484)

Currently casting it as a workaround till we have full support in OS.
Fixes #pytorch#88319 (comment)

Pull Request resolved: pytorch#94484
Approved by: https://github.com/razarmehr
  • Loading branch information
kulinseth authored and pytorchmergebot committed Feb 10, 2023
1 parent 715f373 commit 8dbe63c
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions aten/src/ATen/native/mps/operations/ReduceOps.mm
Original file line number Diff line number Diff line change
Expand Up @@ -1251,7 +1251,7 @@ Tensor std_mps(
(const Tensor& input_t,
MPSReductionType reduction_type,
const std::string& func_name) {
TORCH_CHECK(input_t.scalar_type() != ScalarType::Long, "MPS does not support min/max ops with int64 input");
TORCH_WARN_ONCE(input_t.scalar_type() != ScalarType::Long, "MPS: no support for int64 min/max ops, casting it to int32");

using CachedGraph = MPSUnaryCachedGraph;

Expand Down Expand Up @@ -1280,6 +1280,7 @@ Tensor std_mps(

MPSGraphTensor* outputTensor = nil;
MPSGraphTensor* castInputTensor = nil;
MPSGraphTensor* castOutputTensor = nil;

if (input_t.scalar_type() != ScalarType::Float &&
input_t.scalar_type() != ScalarType::Int &&
Expand All @@ -1302,8 +1303,15 @@ Tensor std_mps(
name:nil];
}

if(input_t.scalar_type() == ScalarType::Long) {
castOutputTensor = [mpsGraph castTensor:outputTensor
toType:MPSDataTypeInt64
name:@"castInputTensor"];
} else {
castOutputTensor = outputTensor;
}
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->outputTensor_ = outputTensor;
newCachedGraph->outputTensor_ = castOutputTensor;
}
return newCachedGraph;
});
Expand Down

0 comments on commit 8dbe63c

Please sign in to comment.