
Commit 3d18adf

[Tutorial, QNN] Add tutorial for loading quantized PyTorch model (#5321)
* add pytorch tutorial code and doc stub
* add more docs
* formatting, more docs
* typo fix
* try make sphinx happy
* add performance section
* typo and nit fix
* format fix
1 parent 09eb508 commit 3d18adf

File tree

3 files changed: +240 −3 lines changed

docs/dev/relay_pass_infra.rst

Lines changed: 1 addition & 1 deletion
@@ -612,7 +612,7 @@ sequential pass example could be like the following to enable IR dumping for
     seq = tvm.transform.Sequential([
         relay.transform.InferType(),
         relay.transform.FoldConstant(),
-        relay.transform.PrintIR(),
+        transform.PrintIR(),
         relay.transform.EliminateCommonSubexpr(),
         relay.transform.AlterOpLayout()
     ])
tutorials/frontend/deploy_prequantized.py

Lines changed: 237 additions & 0 deletions
@@ -0,0 +1,237 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Deploy a Framework-prequantized Model with TVM
==============================================
**Author**: `Masahiro Masuda <https://github.com/masahi>`_

This is a tutorial on loading models quantized by deep learning frameworks into TVM.
Pre-quantized model import is one of the quantization approaches supported in TVM. More details
on the quantization story in TVM can be found
`here <https://discuss.tvm.ai/t/quantization-story/3920>`_.

Here, we demonstrate how to load and run models quantized by PyTorch, MXNet, and TFLite.
Once loaded, we can run compiled, quantized models on any hardware TVM supports.
"""
#################################################################################
# First, necessary imports
from PIL import Image

import numpy as np

import torch
from torchvision.models.quantization import mobilenet as qmobilenet

import tvm
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.contrib.download import download_testdata

#################################################################################
# Helper functions to run the demo
def get_transform():
    import torchvision.transforms as transforms
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])


def get_real_image(im_height, im_width):
    img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'
    img_path = download_testdata(img_url, 'cat.png', module='data')
    return Image.open(img_path).resize((im_height, im_width))


def get_imagenet_input():
    im = get_real_image(224, 224)
    preprocess = get_transform()
    pt_tensor = preprocess(im)
    return np.expand_dims(pt_tensor.numpy(), 0)


def get_synset():
    synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
                          '4d0b62f3d01426887599d4f7ede23ee5/raw/',
                          '596b27d23537e5a1b5751d2b0481ef172f58b539/',
                          'imagenet1000_clsid_to_human.txt'])
    synset_name = 'imagenet1000_clsid_to_human.txt'
    synset_path = download_testdata(synset_url, synset_name, module='data')
    with open(synset_path) as f:
        return eval(f.read())


def run_tvm_model(mod, params, input_name, inp, target="llvm"):
    with relay.build_config(opt_level=3):
        json, lib, params = relay.build(mod, target=target, params=params)

    runtime = graph_runtime.create(json, lib, tvm.context(target, 0))
    runtime.set_input(**params)

    runtime.set_input(input_name, inp)
    runtime.run()
    return runtime.get_output(0).asnumpy(), runtime

#################################################################################
# A mapping from label to class name, to verify that the outputs from the models
# below are reasonable
synset = get_synset()

#################################################################################
# Everyone's favorite cat image for demonstration
inp = get_imagenet_input()

################################################################################
# Deploy a quantized PyTorch Model
# --------------------------------
# First, we demonstrate how to load deep learning models quantized by PyTorch,
# using our PyTorch frontend.
#
# Please refer to the PyTorch static quantization tutorial below to learn about
# their quantization workflow.
# https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html
#
# We use this function to quantize PyTorch models.
# In short, this function takes a floating point model and converts it to uint8.
# The model is per-channel quantized.

def quantize_model(model, inp):
    model.fuse_model()
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.quantization.prepare(model, inplace=True)
    # Dummy calibration
    model(inp)
    torch.quantization.convert(model, inplace=True)

##############################################################################
# Load quantization-ready, pretrained Mobilenet v2 model from torchvision
# -----------------------------------------------------------------------
# We choose mobilenet v2 because this model was trained with quantization-aware
# training. Other models require a full post-training calibration.
qmodel = qmobilenet.mobilenet_v2(pretrained=True).eval()

##############################################################################
# Quantize, trace and run the PyTorch Mobilenet v2 model
# ------------------------------------------------------
# The details are out of scope for this tutorial. Please refer to the tutorials
# on the PyTorch website to learn about quantization and jit.
pt_inp = torch.from_numpy(inp)
quantize_model(qmodel, pt_inp)
script_module = torch.jit.trace(qmodel, pt_inp).eval()

with torch.no_grad():
    pt_result = script_module(pt_inp).numpy()

##############################################################################
# Convert quantized Mobilenet v2 to Relay-QNN using the PyTorch frontend
# ----------------------------------------------------------------------
# The PyTorch frontend has support for converting a quantized PyTorch model to
# an equivalent Relay module enriched with quantization-aware operators.
# We call this representation Relay QNN dialect.
#
# You can print the output from the frontend to see how quantized models are
# represented.
#
# You will see quantization-specific operators such as
# qnn.quantize, qnn.dequantize, qnn.requantize, and qnn.conv2d.
input_name = "input"  # the input name can be arbitrary for the PyTorch frontend.
input_shapes = [(input_name, (1, 3, 224, 224))]
mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
# print(mod)  # uncomment to see the QNN IR dump

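##############################################################################
# If you prefer a programmatic check over reading the full IR dump, the
# following minimal sketch counts operator occurrences in the imported module,
# so you can confirm that QNN operators such as qnn.conv2d are indeed present.
# It assumes the ``mod`` obtained above and the ``post_order_visit`` visitor
# from ``relay.analysis``; it is an illustration, not part of the deployment
# flow.
from collections import Counter

op_freq = Counter()


def count_op(expr):
    # count every call whose callee is a named Relay/QNN operator
    if isinstance(expr, relay.Call) and isinstance(expr.op, tvm.ir.Op):
        op_freq[expr.op.name] += 1


relay.analysis.post_order_visit(mod["main"], count_op)
print(op_freq.most_common(10))
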
##############################################################################
# Compile and run the Relay module
# --------------------------------
# Once we obtain the quantized Relay module, the rest of the workflow
# is the same as running floating point models. Please refer to other
# tutorials for more details.
#
# Under the hood, quantization-specific operators are lowered to a sequence of
# standard Relay operators before compilation.
tvm_result, rt_mod = run_tvm_model(mod, params, input_name, inp, target="llvm")

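##############################################################################
# As a rough illustration of that lowering (a sketch, assuming the QNN
# canonicalization pass ``relay.qnn.transform.CanonicalizeOps`` available in
# this TVM version), you can run the pass manually and inspect the result:
seq = tvm.transform.Sequential([relay.transform.InferType(),
                                relay.qnn.transform.CanonicalizeOps()])
with relay.build_config(opt_level=3):
    qnn_lowered = seq(mod)
# print(qnn_lowered)  # uncomment: qnn.* calls now appear as standard Relay ops
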
##########################################################################
# Compare the output labels
# -------------------------
# We should see identical labels printed.
pt_top3_labels = np.argsort(pt_result[0])[::-1][:3]
tvm_top3_labels = np.argsort(tvm_result[0])[::-1][:3]

print("PyTorch top3 labels:", [synset[label] for label in pt_top3_labels])
print("TVM top3 labels:", [synset[label] for label in tvm_top3_labels])

###########################################################################################
# However, due to the difference in numerics, in general the raw floating point
# outputs are not expected to be identical. Here, we print how many floating point
# output values are identical out of the 1000 outputs from mobilenet v2.
print("%d in 1000 raw floating outputs identical." % np.sum(tvm_result[0] == pt_result[0]))

##########################################################################
# Measure performance
# -------------------
# Here we give an example of how to measure performance of TVM-compiled models.
n_repeat = 100  # should be larger to make the measurement statistically stable
ctx = tvm.cpu(0)
ftimer = rt_mod.module.time_evaluator("run", ctx, number=1,
                                      repeat=n_repeat)
prof_res = np.array(ftimer().results) * 1e3
print("Elapsed average ms:", np.mean(prof_res))

######################################################################
# .. note::
#
#   We recommend this method for the following reasons:
#
#   * Measurements are done in C++, so there is no Python overhead
#   * It includes several warm up runs
#   * The same method can be used to profile on remote devices (Android etc.),
#     as the sketch below shows.
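######################################################################
# The sketch below shows how the same measurement could run on a remote device
# over RPC. Everything device-specific here is hypothetical: the address and
# port, the file name, and the assumption that a TVM RPC server is running on
# the device; the block is guarded by a flag so this tutorial stays runnable
# as-is.
run_remote_benchmark = False  # flip to True with a real device and RPC server

if run_remote_benchmark:
    from tvm import rpc
    from tvm.contrib import util

    # rebuild to get the deployable artifacts (run_tvm_model keeps them local)
    with relay.build_config(opt_level=3):
        graph_json, lib, lib_params = relay.build(mod, target="llvm", params=params)

    remote = rpc.connect("192.168.1.100", 9090)  # hypothetical address and port
    tmp = util.tempdir()
    lib.export_library(tmp.relpath("net.tar"))
    remote.upload(tmp.relpath("net.tar"))
    rlib = remote.load_module("net.tar")

    rt = graph_runtime.create(graph_json, rlib, remote.cpu(0))
    rt.set_input(**lib_params)
    rt.set_input(input_name, inp)
    remote_ftimer = rt.module.time_evaluator("run", remote.cpu(0),
                                             number=1, repeat=n_repeat)
    print("Remote elapsed average ms:",
          np.mean(np.array(remote_ftimer().results) * 1e3))
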
######################################################################
# .. note::
#
#   Unless the hardware has special support for fast 8 bit instructions, quantized models
#   are not expected to be any faster than FP32 models. Without fast 8 bit instructions,
#   TVM does quantized convolution in 16 bit, even if the model itself is 8 bit.
#
#   For x86, the best performance can be achieved on CPUs with the AVX512 instruction set.
#   In this case, TVM utilizes the fastest available 8 bit instructions for the given
#   target. This includes support for the VNNI 8 bit dot product instruction
#   (Cascade Lake or newer).
#
#   Moreover, the following general tips for CPU performance apply equally:
#
#   * Set the environment variable TVM_NUM_THREADS to the number of physical cores
#   * Choose the best target for your hardware, such as "llvm -mcpu=skylake-avx512" or
#     "llvm -mcpu=cascadelake" (more AVX512-capable CPUs will be supported in the future)
#
#   A hypothetical example of applying both tips is sketched below.
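######################################################################
# A hypothetical example of applying both tips (the core count and the -mcpu
# value below are placeholders; the run itself is commented out because the
# -mcpu value must match the machine executing this script):
import os

# Note: TVM_NUM_THREADS is read when the runtime creates its thread pool, so
# in practice you would export it in the shell before launching Python.
os.environ["TVM_NUM_THREADS"] = "4"  # hypothetical: number of physical cores

avx512_target = "llvm -mcpu=cascadelake"
# tvm_result, rt_mod = run_tvm_model(mod, params, input_name, inp,
#                                    target=avx512_target)
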
###############################################################################
# Deploy a quantized MXNet Model
# ------------------------------
# TODO

###############################################################################
# Deploy a quantized TFLite Model
# -------------------------------
# TODO

tutorials/frontend/from_pytorch.py

Lines changed: 2 additions & 2 deletions
@@ -88,8 +88,8 @@
 ######################################################################
 # Import the graph to Relay
 # -------------------------
-# Convert PyTorch graph to Relay graph.
-input_name = 'input0'  # only one input, set it to this name
+# Convert PyTorch graph to Relay graph. The input name can be arbitrary.
+input_name = 'input0'
 shape_list = [(input_name, img.shape)]
 mod, params = relay.frontend.from_pytorch(scripted_model,
                                           shape_list)
