Skip to content

Commit

Permalink
Added implementation and tests for MPS Hardswish (pytorch#87952)
Browse files Browse the repository at this point in the history
## What?
Fixes issue pytorch#86807 by adding MPS backend support for aten::hardswish.

## How?
Registered MPS hardswish functions in native_functions.yaml, and added the code implementation to Activations.mm.

Added functions:
- hardswish_mps
- hardswish_mps_
- hardswish_backward_mps
- hardswish_out_mps

## Testing
Added test in test/test_mps.py and tested code using the command `python3 test/test_mps.py -k test_hardswish`

Pull Request resolved: pytorch#87952
Approved by: https://github.com/kulinseth, https://github.com/kit1980
  • Loading branch information
thomaslin2020 authored and pytorchmergebot committed Nov 23, 2022
1 parent 1cfd385 commit 4935b59
Show file tree
Hide file tree
Showing 3 changed files with 350 additions and 62 deletions.
252 changes: 252 additions & 0 deletions aten/src/ATen/native/mps/operations/Activation.mm
Original file line number Diff line number Diff line change
Expand Up @@ -2202,5 +2202,257 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) {
return grad_input;
}

// Computes hardswish on the MPS backend, writing the result into `output`:
//   x <= -3       ->  0
//   -3 < x < 3    ->  x * (x + 3) / 6
//   x >= 3        ->  x
// The piecewise function is expressed as a small MPSGraph (two selects over
// constant thresholds) that is built once and cached per shape/dtype key.
// Returns `output` to support the standard out-variant calling convention.
Tensor& hardswish_out_mps(const Tensor& self, Tensor& output) {
  using namespace mps;
  using CachedGraph = MPSUnaryCachedGraph;

  TORCH_CHECK(self.is_mps());

  // Nothing to compute for an empty tensor.
  if (output.numel() == 0) {
    return output;
  }

  MPSGraphCache* cache_ = MPSGraphCache::getInstance();

  MPSStream* stream = at::mps::getCurrentMPSStream();

  @autoreleasepool {
    // Cache key embeds the input's tensor key so a compiled graph is reused
    // for repeated calls with matching shape/dtype.
    string key = "hardswish_out_mps" + getTensorsStringKey({self});
    CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
    if (!cachedGraph) {
      MPSCachedGraph* tmpCachedGraph =
          cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
            CachedGraph* newCachedGraph = nil;
            @autoreleasepool {
              MPSGraph* mpsGraph = make_mps_graph();
              newCachedGraph = new CachedGraph(mpsGraph);
              MPSGraphTensor* inputTensor =
                  mpsGraphRankedPlaceHolder(mpsGraph, self);

              // Scalar constants for the piecewise definition. Each is created
              // with the input's dtype so the graph ops need no casts.
              MPSGraphTensor* zeroTensor = [mpsGraph
                  constantWithScalar:0.0f
                               shape:@[ @1 ]
                            dataType:getMPSDataType(self.scalar_type())];

              MPSGraphTensor* threeTensor = [mpsGraph
                  constantWithScalar:3.0f
                               shape:@[ @1 ]
                            dataType:getMPSDataType(self.scalar_type())];

              MPSGraphTensor* negativeThreeTensor = [mpsGraph
                  constantWithScalar:-3.0f
                               shape:@[ @1 ]
                            dataType:getMPSDataType(self.scalar_type())];

              MPSGraphTensor* sixTensor = [mpsGraph
                  constantWithScalar:6.0f
                               shape:@[ @1 ]
                            dataType:getMPSDataType(self.scalar_type())];

              // Predicate: x <= -3 (selects the zero branch).
              MPSGraphTensor* lessThanMinPredicateTensor = [mpsGraph
                  lessThanOrEqualToWithPrimaryTensor:inputTensor
                                     secondaryTensor:negativeThreeTensor
                                                name:nil];

              // Predicate: x < 3 (selects the weighted middle branch;
              // x >= 3 passes the input through unchanged).
              MPSGraphTensor* lessThanMaxPredicateTensor =
                  [mpsGraph lessThanWithPrimaryTensor:inputTensor
                                      secondaryTensor:threeTensor
                                                 name:nil];

              // (x + 3)
              MPSGraphTensor* inputPlusThreeTensor =
                  [mpsGraph additionWithPrimaryTensor:inputTensor
                                      secondaryTensor:threeTensor
                                                 name:nil];

              // (x + 3) / 6
              MPSGraphTensor* inputDivSixTensor =
                  [mpsGraph divisionWithPrimaryTensor:inputPlusThreeTensor
                                      secondaryTensor:sixTensor
                                                 name:nil];

              // x * (x + 3) / 6 — the middle branch of hardswish.
              MPSGraphTensor* weightedTensor =
                  [mpsGraph multiplicationWithPrimaryTensor:inputTensor
                                            secondaryTensor:inputDivSixTensor
                                                       name:nil];

              // x < 3 ? x*(x+3)/6 : x
              MPSGraphTensor* tempTensor =
                  [mpsGraph selectWithPredicateTensor:lessThanMaxPredicateTensor
                                  truePredicateTensor:weightedTensor
                                 falsePredicateTensor:inputTensor
                                                 name:nil];

              // x <= -3 ? 0 : tempTensor — final piecewise result.
              MPSGraphTensor* outputTensor =
                  [mpsGraph selectWithPredicateTensor:lessThanMinPredicateTensor
                                  truePredicateTensor:zeroTensor
                                 falsePredicateTensor:tempTensor
                                                 name:nil];
              newCachedGraph->inputTensor_ = inputTensor;
              newCachedGraph->outputTensor_ = outputTensor;
            }
            return newCachedGraph;
          });
      cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
    }
    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
    Placeholder outputPlaceholder =
        Placeholder(cachedGraph->outputTensor_, output);

    // Create dictionary of inputs and outputs
    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
      selfPlaceholder.getMPSGraphTensor() :
          selfPlaceholder.getMPSGraphTensorData()
    };

    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
      outputPlaceholder.getMPSGraphTensor() :
          outputPlaceholder.getMPSGraphTensorData()
    };

    runMPSGraph(stream, cachedGraph->graph(), feeds, results);
  }
  return output;
}

// Functional (non-in-place) hardswish: allocates a fresh result tensor and
// delegates the actual graph execution to the out-variant.
Tensor hardswish_mps(const Tensor& self) {
  using namespace mps;
  // Destination matches the input's sizes/dtype and its suggested memory
  // format; the shared out-kernel fills it.
  Tensor result = at::empty_like(self, self.suggest_memory_format());
  return hardswish_out_mps(self, result);
}

// In-place hardswish: the input tensor doubles as the output buffer, so we
// simply pass `self` to the out-variant for both roles.
Tensor& hardswish_mps_(Tensor& self) {
  using namespace mps;
  return hardswish_out_mps(self, self);
}

// Backward pass of hardswish on the MPS backend. The derivative w.r.t. the
// input, taken from the two selects below, is:
//   x <= -3       ->  0
//   -3 < x < 3    ->  x / 3 + 0.5
//   x >= 3        ->  1
// and the returned grad_input is that derivative multiplied elementwise by
// grad_output.
Tensor hardswish_backward_mps(const Tensor& grad_output, const Tensor& self) {
  using namespace mps;

  // NOTE(review): for an empty grad_output this returns grad_output itself
  // rather than an empty tensor allocated like `self`; presumably both are
  // empty together in that case — confirm against callers.
  if (grad_output.numel() == 0) {
    return grad_output;
  }

  Tensor grad_input = at::empty_like(self, self.suggest_memory_format());

  // Two-input graph (grad_output, input) producing one output (grad_input);
  // the generic unary cached-graph type doesn't fit, so declare one locally.
  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* gradOutputTensor_ = nil;
    MPSGraphTensor* inputTensor_ = nil;
    MPSGraphTensor* gradInputTensor_ = nil;
  };

  MPSGraphCache* cache_ = MPSGraphCache::getInstance();

  MPSStream* stream = at::mps::getCurrentMPSStream();

  @autoreleasepool {
    // Cache key embeds the input's tensor key so the compiled graph is reused
    // for matching shape/dtype.
    string key = "hardswish_backward_mps" + getTensorsStringKey({self});
    CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
    if (!cachedGraph) {
      MPSCachedGraph* tmpCachedGraph =
          cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
            CachedGraph* newCachedGraph = nil;
            @autoreleasepool {
              MPSGraph* mpsGraph = make_mps_graph();
              newCachedGraph = new CachedGraph(mpsGraph);
              MPSGraphTensor* gradOutputTensor =
                  mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
              MPSGraphTensor* inputTensor =
                  mpsGraphRankedPlaceHolder(mpsGraph, self);

              // Scalar constants for the derivative's pieces; all created with
              // grad_output's dtype so the graph ops need no casts.
              MPSGraphTensor* zeroTensor = [mpsGraph
                  constantWithScalar:0.0f
                               shape:@[ @1 ]
                            dataType:getMPSDataType(grad_output.scalar_type())];

              MPSGraphTensor* unitTensor = [mpsGraph
                  constantWithScalar:1.0f
                               shape:@[ @1 ]
                            dataType:getMPSDataType(grad_output.scalar_type())];

              MPSGraphTensor* threeTensor = [mpsGraph
                  constantWithScalar:3.0f
                               shape:@[ @1 ]
                            dataType:getMPSDataType(grad_output.scalar_type())];

              MPSGraphTensor* negativeThreeTensor = [mpsGraph
                  constantWithScalar:-3.0f
                               shape:@[ @1 ]
                            dataType:getMPSDataType(grad_output.scalar_type())];

              MPSGraphTensor* halfTensor = [mpsGraph
                  constantWithScalar:0.5f
                               shape:@[ @1 ]
                            dataType:getMPSDataType(grad_output.scalar_type())];

              // x / 3
              MPSGraphTensor* tempTensor =
                  [mpsGraph divisionWithPrimaryTensor:inputTensor
                                      secondaryTensor:threeTensor
                                                 name:nil];

              // x / 3 + 0.5 — derivative for the middle region.
              MPSGraphTensor* weightedTensor =
                  [mpsGraph additionWithPrimaryTensor:tempTensor
                                      secondaryTensor:halfTensor
                                                 name:nil];

              // Predicate: x <= -3 (derivative is 0 there).
              MPSGraphTensor* lessThanMinPredicateTensor = [mpsGraph
                  lessThanOrEqualToWithPrimaryTensor:inputTensor
                                     secondaryTensor:negativeThreeTensor
                                                name:nil];

              // Predicate: x < 3 (derivative is 1 for x >= 3).
              MPSGraphTensor* lessThanMaxPredicateTensor =
                  [mpsGraph lessThanWithPrimaryTensor:inputTensor
                                      secondaryTensor:threeTensor
                                                 name:nil];

              // x < 3 ? x/3 + 0.5 : 1
              MPSGraphTensor* lessThanMaxGradTensor =
                  [mpsGraph selectWithPredicateTensor:lessThanMaxPredicateTensor
                                  truePredicateTensor:weightedTensor
                                 falsePredicateTensor:unitTensor
                                                 name:nil];

              // x <= -3 ? 0 : lessThanMaxGradTensor — full derivative.
              MPSGraphTensor* gradTensor =
                  [mpsGraph selectWithPredicateTensor:lessThanMinPredicateTensor
                                  truePredicateTensor:zeroTensor
                                 falsePredicateTensor:lessThanMaxGradTensor
                                                 name:nil];
              // Chain rule: derivative * upstream gradient.
              MPSGraphTensor* gradInputTensor =
                  [mpsGraph multiplicationWithPrimaryTensor:gradTensor
                                            secondaryTensor:gradOutputTensor
                                                       name:nil];

              newCachedGraph->gradOutputTensor_ = gradOutputTensor;
              newCachedGraph->inputTensor_ = inputTensor;
              newCachedGraph->gradInputTensor_ = gradInputTensor;
            }
            return newCachedGraph;
          });
      cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
    }

    Placeholder gradOutputPlaceholder =
        Placeholder(cachedGraph->gradOutputTensor_, grad_output);
    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
    Placeholder gradInputPlaceholder =
        Placeholder(cachedGraph->gradInputTensor_, grad_input);

    // Create dictionary of inputs and outputs
    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
      gradOutputPlaceholder.getMPSGraphTensor() :
          gradOutputPlaceholder.getMPSGraphTensorData(),
      selfPlaceholder.getMPSGraphTensor() :
          selfPlaceholder.getMPSGraphTensorData()
    };

    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
      gradInputPlaceholder.getMPSGraphTensor() :
          gradInputPlaceholder.getMPSGraphTensorData()
    };

    runMPSGraph(stream, cachedGraph->graph(), feeds, results);
  }
  return grad_input;
}
} // namespace native
} // namespace at
4 changes: 4 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10393,23 +10393,27 @@
python_module: nn
dispatch:
CPU, CUDA: hardswish_out
MPS: hardswish_out_mps

- func: hardswish(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
python_module: nn
dispatch:
CPU, CUDA: hardswish
MPS: hardswish_mps

- func: hardswish_(Tensor(a!) self) -> Tensor(a!)
device_check: NoCheck # TensorIterator
python_module: nn
dispatch:
CPU, CUDA: hardswish_
MPS: hardswish_mps_

- func: hardswish_backward(Tensor grad_output, Tensor self) -> Tensor
python_module: nn
dispatch:
CPU, CUDA: hardswish_backward
MPS: hardswish_backward_mps
autogen: hardswish_backward.out

- func: leaky_relu.out(Tensor self, Scalar negative_slope=0.01, *, Tensor(a!) out) -> Tensor(a!)
Expand Down
Loading

0 comments on commit 4935b59

Please sign in to comment.