Skip to content

Commit 495070b

Browse files
xta0 authored and facebook-github-bot committed
[Metal] Add the Python binding for optimize_for_mobile (pytorch#46456)
Summary:
Pull Request resolved: pytorch#46456

Add the Python binding in CMake. The general workflow is:

- Build pytorch - `USE_PYTORCH_METAL=ON python setup.py install --cmake`
- Run optimize_for_mobile

```
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

scripted_model = torch.jit.load('./mobilenetv2.pt')
optimized_model = optimize_for_mobile(scripted_model, backend='metal')
torch.jit.export_opnames(optimized_model)
torch.jit.save(optimized_model, './mobilenetv2_metal.bc')
```

The exported ops are:

```
['aten::adaptive_avg_pool2d', 'aten::add.Tensor', 'aten::addmm', 'aten::reshape', 'aten::size.int', 'metal::copy_to_host', 'metal_prepack::conv2d_run']
```

ghstack-source-id: 114559878

Test Plan:
- Sandcastle CI
- Circle CI

Reviewed By: kimishpatel

Differential Revision: D24356768

fbshipit-source-id: fb5c4c4b6316347b67edb4132da044a81470ddfd
1 parent e8ff0f6 commit 495070b

File tree

6 files changed

+212
-6
lines changed

6 files changed

+212
-6
lines changed

CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,10 @@ if(USE_VULKAN_RELAXED_PRECISION)
548548
string(APPEND CMAKE_CXX_FLAGS " -DUSE_VULKAN_RELAXED_PRECISION")
549549
endif()
550550

551+
if(USE_PYTORCH_METAL)
552+
string(APPEND CMAKE_CXX_FLAGS " -DUSE_PYTORCH_METAL")
553+
endif()
554+
551555
# ---[ Allowlist file if allowlist is specified
552556
include(cmake/Allowlist.cmake)
553557

aten/src/ATen/CMakeLists.txt

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,15 @@ file(GLOB native_mkldnn_cpp "native/mkldnn/*.cpp")
6767
file(GLOB vulkan_cpp "vulkan/*.cpp")
6868
file(GLOB native_vulkan_cpp "native/vulkan/api/*.cpp" "native/vulkan/*.cpp")
6969

70+
# Metal
7071
file(GLOB metal_h "metal/*.h")
7172
file(GLOB metal_cpp "metal/*.cpp")
7273
file(GLOB_RECURSE native_metal_h "native/metal/*.h")
7374
file(GLOB metal_test_srcs "native/metal/mpscnn/tests/*.mm")
7475
file(GLOB_RECURSE native_metal_srcs "native/metal/*.mm", "native/metal/*.cpp")
7576
EXCLUDE(native_metal_srcs "${native_metal_srcs}" ${metal_test_srcs})
77+
file(GLOB metal_prepack_h "native/metal/MetalPrepackOpContext.h")
78+
file(GLOB metal_prepack_cpp "native/metal/MetalPrepackOpRegister.cpp")
7679

7780
file(GLOB native_sparse_cpp "native/sparse/*.cpp")
7881
file(GLOB native_quantized_cpp
@@ -125,8 +128,14 @@ else()
125128
set(all_cpu_cpp ${all_cpu_cpp} ${vulkan_cpp})
126129
endif()
127130

131+
# Metal
128132
if(USE_PYTORCH_METAL)
129-
set(all_cpu_cpp ${all_cpu_cpp} ${metal_cpp} ${native_metal_srcs})
133+
if(IOS)
134+
set(all_cpu_cpp ${all_cpu_cpp} ${metal_cpp} ${native_metal_srcs})
135+
else()
136+
# Add the files needed by optimize_for_mobile
137+
set(all_cpu_cpp ${all_cpu_cpp} ${metal_cpp} ${metal_prepack_cpp})
138+
endif()
130139
else()
131140
set(all_cpu_cpp ${all_cpu_cpp} ${metal_cpp})
132141
endif()
@@ -391,7 +400,11 @@ if(NOT INTERN_BUILD_MOBILE)
391400
list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${miopen_h})
392401
else()
393402
if(USE_PYTORCH_METAL)
394-
list(APPEND INSTALL_HEADERS ${metal_h} ${native_metal_h})
403+
if(IOS)
404+
list(APPEND INSTALL_HEADERS ${metal_h} ${native_metal_h})
405+
else()
406+
list(APPEND INSTALL_HEADERS ${metal_h} ${metal_prepack_h})
407+
endif()
395408
endif()
396409
endif()
397410

test/test_metal.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import torch
2+
from torch.nn import functional as F
3+
4+
from torch.testing._internal.common_utils import TestCase, run_tests
5+
from torch.testing import FileCheck
6+
import io
7+
8+
class TestMetalRewritePass(TestCase):
    """Graph-level checks for the Metal TorchScript rewrite passes.

    Each scenario scripts a small convolution module, applies the Metal JIT
    passes exposed on ``torch._C``, round-trips the result through
    ``torch.jit.save`` / ``torch.jit.load`` and then inspects the reloaded
    graph with ``FileCheck``.  Only the graph structure is verified; the
    rewritten module is never executed.
    """

    @staticmethod
    def validate_transformed_module(
            # To please flake: this is a staticmethod, but the first
            # positional argument is still called ``self``.  It receives the
            # ``torch.nn.Module`` instance under test.
            self,
            pattern_count_map,
            data_shape,
            prepack_removal=False,
            fuse_clamping_ops=False):
        """Apply the Metal rewrite passes to a module and check its graph.

        Args:
            self: The ``torch.nn.Module`` instance to script and rewrite.
            pattern_count_map: Maps a FileCheck pattern to its expected
                occurrence count in the reloaded graph.  ``0`` means the
                pattern must appear (any number of times), ``-1`` means it
                must not appear, and any other value is an exact count.
            data_shape: Shape of the random input used for the reference
                run of the unmodified scripted module.
            prepack_removal: When true, the module is frozen and
                ``_jit_pass_metal_fold_prepacking_ops`` is applied.
            fuse_clamping_ops: When true, the module is frozen and
                ``_jit_pass_metal_fuse_clamp_w_prepacked_conv`` is applied.
        """
        module_instance = self
        scripted_model = torch.jit.script(module_instance)
        scripted_model.eval()
        input_data = torch.normal(1, 20, size=data_shape)
        # Run the untouched scripted module once so that a broken module or a
        # bad input shape fails here rather than inside a rewrite pass.  The
        # output is intentionally discarded: the rewritten graph is not
        # executed below, so there is nothing to compare it against.
        scripted_model(input_data)

        torch._C._jit_pass_metal_insert_prepacked_ops(scripted_model._c)
        if fuse_clamping_ops or prepack_removal:
            # Both optional passes operate on a frozen module.
            scripted_model._c = torch._C._freeze_module(scripted_model._c)
        if fuse_clamping_ops:
            torch._C._jit_pass_metal_fuse_clamp_w_prepacked_conv(scripted_model._c)
        if prepack_removal:
            torch._C._jit_pass_metal_fold_prepacking_ops(scripted_model._c)

        # Serialize and reload so the checks run on what would actually be
        # shipped, not on the in-memory module.
        buffer = io.BytesIO()
        torch.jit.save(scripted_model, buffer)
        buffer.seek(0)
        deserialized_scripted_model = torch.jit.load(buffer)
        graph = deserialized_scripted_model.graph
        for pattern, expected in pattern_count_map.items():
            if expected == 0:
                FileCheck().check(pattern).run(graph)
            elif expected == -1:
                FileCheck().check_not(pattern).run(graph)
            else:
                FileCheck().check_count(pattern, expected, exactly=True).run(graph)

    def test_conv(self):
        """Conv2d, Conv2d+ReLU and Conv2d+Hardtanh are rewritten to Metal ops."""
        # Conv params
        batch_size = 2
        input_channels_per_group = 6
        height = 16
        width = 16
        output_channels_per_group = 6
        groups = 4
        kernel_h = kernel_w = 3
        stride_h = stride_w = 1
        pad_h = pad_w = 1
        dilation = 1
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
        # One-element tuple: ``(output_channels)`` without the trailing comma
        # is just an int.
        conv_bias_shape = (output_channels,)
        data_shape = (batch_size, input_channels, height, width)

        class Conv2D(torch.nn.Module):
            def __init__(self):
                super(Conv2D, self).__init__()
                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                return F.conv2d(x, self.weight, self.bias,
                                self.strides, self.paddings, self.dilations, self.groups)

        # Plain convolution: aten::conv2d is replaced by a prepack/run pair.
        pattern_count_map = {"Tensor = aten::conv2d": -1,
                             "metal_prepack::conv2d_prepack": 1,
                             "metal_prepack::conv2d_run": 1}
        TestMetalRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape)

        class Conv2DRelu(torch.nn.Module):
            def __init__(self):
                super(Conv2DRelu, self).__init__()
                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.conv2d(x, self.weight, self.bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = F.relu(o)
                return o

        pattern_count_map = {"Tensor = aten::conv2d": -1,
                             "metal_prepack::conv2d_prepack": 1,
                             "metal_prepack::conv2d_run": 1}
        TestMetalRewritePass.validate_transformed_module(
            Conv2DRelu(), pattern_count_map, data_shape)

        # After folding, the prepack op is gone but the ReLU is still a
        # separate node.
        pattern_count_map["aten::relu"] = 1
        pattern_count_map["metal_prepack::conv2d_prepack"] = -1
        TestMetalRewritePass.validate_transformed_module(
            Conv2DRelu(),
            pattern_count_map,
            data_shape,
            prepack_removal=True)
        # With clamp fusion enabled the ReLU must disappear from the graph.
        pattern_count_map["aten::relu"] = -1
        TestMetalRewritePass.validate_transformed_module(
            Conv2DRelu(),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True)

        class Conv2DHardtanh(torch.nn.Module):
            def __init__(self):
                super(Conv2DHardtanh, self).__init__()
                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.conv2d(x, self.weight, self.bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = F.hardtanh(o)
                return o

        pattern_count_map = {"Tensor = aten::conv2d": -1,
                             "metal_prepack::conv2d_prepack": 1,
                             "metal_prepack::conv2d_run": 1}
        TestMetalRewritePass.validate_transformed_module(Conv2DHardtanh(), pattern_count_map, data_shape)
        pattern_count_map["aten::hardtanh"] = 1
        pattern_count_map["metal_prepack::conv2d_prepack"] = -1
        TestMetalRewritePass.validate_transformed_module(
            Conv2DHardtanh(),
            pattern_count_map,
            data_shape,
            prepack_removal=True)
        # The module under test here must be Conv2DHardtanh.  Using
        # Conv2DRelu would make the "no aten::hardtanh" check pass trivially
        # and leave hardtanh fusion untested.
        pattern_count_map["aten::hardtanh"] = -1
        TestMetalRewritePass.validate_transformed_module(
            Conv2DHardtanh(),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True)
157+
158+
# Script entry point: when this file is executed directly, hand control to
# run_tests(), which is imported from torch.testing._internal.common_utils.
if __name__ == "__main__":
159+
run_tests()

torch/_C/__init__.pyi.in

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,8 @@ def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule',
176176
preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
177177
def _jit_pass_vulkan_optimize_for_mobile(module: 'torch.jit.ScriptModule',
178178
preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
179+
def _jit_pass_metal_optimize_for_mobile(module: 'torch.jit.ScriptModule',
180+
preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
179181
def _jit_pass_inline(Graph) -> None: ...
180182
def _jit_get_schemas_for_operator(name :str) -> List[FunctionSchema]: ...
181183
def _jit_can_fuse_on_cpu() -> _bool: ...

torch/csrc/jit/python/init.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <torch/csrc/jit/passes/loop_unrolling.h>
3030
#include <torch/csrc/jit/passes/lower_graph.h>
3131
#include <torch/csrc/jit/passes/lower_tuples.h>
32+
#include <torch/csrc/jit/passes/metal_rewrite.h>
3233
#include <torch/csrc/jit/passes/normalize_ops.h>
3334
#include <torch/csrc/jit/passes/onnx.h>
3435
#include <torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h>
@@ -679,6 +680,30 @@ void initJITBindings(PyObject* module) {
679680
std::vector<std::string>& preserved_methods) {
680681
return vulkanOptimizeForMobile(module, preserved_methods);
681682
})
683+
.def(
684+
"_jit_pass_metal_insert_prepacked_ops",
685+
[](std::shared_ptr<Graph>& graph) {
686+
return metalInsertPrePackedOps(graph);
687+
})
688+
.def(
689+
"_jit_pass_metal_insert_prepacked_ops",
690+
[](script::Module& module) {
691+
return metalInsertPrePackedOps(module);
692+
})
693+
.def(
694+
"_jit_pass_metal_fuse_clamp_w_prepacked_conv",
695+
[](script::Module& module) {
696+
return metalFusePrePackedConvWithClamp(module);
697+
})
698+
.def(
699+
"_jit_pass_metal_fold_prepacking_ops",
700+
[](script::Module& module) { return metalFoldPrePackingOps(module); })
701+
.def(
702+
"_jit_pass_metal_optimize_for_mobile",
703+
[](script::Module& module,
704+
std::vector<std::string>& preserved_methods) {
705+
return metalOptimizeForMobile(module, preserved_methods);
706+
})
682707
.def(
683708
"_jit_pass_onnx_unpack_quantized_weights",
684709
[](std::shared_ptr<Graph>& graph,

torch/utils/mobile_optimizer.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def optimize_for_mobile(
2525
optimization method will run all the optimizer pass; otherwise, optimizer
2626
method will run the optimization pass that is not included inside optimization_blocklist.
2727
preserved_methods: A list of methods that need to be preserved when the freeze_module pass is invoked
28-
backend: Device type to use for running the result model ('CPU'(default) or 'Vulkan').
28+
backend: Device type to use for running the result model ('CPU'(default), 'Vulkan' or 'Metal').
2929
Returns:
3030
A new optimized torch script module
3131
"""
@@ -39,12 +39,15 @@ def optimize_for_mobile(
3939
if preserved_methods is None:
4040
preserved_methods = []
4141

42-
if backend == 'CPU':
42+
backend = backend.lower()
43+
if backend == 'cpu':
4344
optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(script_module._c, optimization_blocklist, preserved_methods)
44-
elif backend == 'Vulkan':
45+
elif backend == 'vulkan':
4546
optimized_cpp_module = torch._C._jit_pass_vulkan_optimize_for_mobile(script_module._c, preserved_methods)
47+
elif backend == 'metal':
48+
optimized_cpp_module = torch._C._jit_pass_metal_optimize_for_mobile(script_module._c, preserved_methods)
4649
else:
47-
raise TypeError("Unknown backend, must be one of 'CPU', 'Vulkan'")
50+
raise TypeError("Unknown backend, must be one of 'CPU', 'Vulkan' or 'Metal'")
4851

4952
return torch.jit._recursive.wrap_cpp_module(optimized_cpp_module)
5053

0 commit comments

Comments
 (0)