# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
| 17 | +""" |
| 18 | +Deploy a Framework-prequantized Model with TVM |
| 19 | +============================================== |
| 20 | +**Author**: `Masahiro Masuda <https://github.com/masahi>`_ |
| 21 | +
|
| 22 | +This is a tutorial on loading models quantized by deep learning frameworks into TVM. |
| 23 | +Pre-quantized model import is one of the quantization support we have in TVM. More details on |
| 24 | +the quantization story in TVM can be found |
| 25 | +`here <https://discuss.tvm.ai/t/quantization-story/3920>`_. |
| 26 | +
|
| 27 | +Here, we demonstrate how to load and run models quantized by PyTorch, MXNet, and TFLite. |
| 28 | +Once loaded, we can run compiled, quantized models on any hardware TVM supports. |
| 29 | +""" |

#################################################################################
# First, necessary imports
from PIL import Image

import numpy as np

import torch
from torchvision.models.quantization import mobilenet as qmobilenet

import tvm
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.contrib.download import download_testdata


#################################################################################
# Helper functions to run the demo
def get_transform():
    import torchvision.transforms as transforms
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])


def get_real_image(im_height, im_width):
    img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'
    img_path = download_testdata(img_url, 'cat.png', module='data')
    return Image.open(img_path).resize((im_height, im_width))


def get_imagenet_input():
    im = get_real_image(224, 224)
    preprocess = get_transform()
    pt_tensor = preprocess(im)
    return np.expand_dims(pt_tensor.numpy(), 0)


def get_synset():
    synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
                          '4d0b62f3d01426887599d4f7ede23ee5/raw/',
                          '596b27d23537e5a1b5751d2b0481ef172f58b539/',
                          'imagenet1000_clsid_to_human.txt'])
    synset_name = 'imagenet1000_clsid_to_human.txt'
    synset_path = download_testdata(synset_url, synset_name, module='data')
    with open(synset_path) as f:
        return eval(f.read())


def run_tvm_model(mod, params, input_name, inp, target="llvm"):
    with relay.build_config(opt_level=3):
        json, lib, params = relay.build(mod, target=target, params=params)

    runtime = graph_runtime.create(json, lib, tvm.context(target, 0))
    runtime.set_input(**params)

    runtime.set_input(input_name, inp)
    runtime.run()
    return runtime.get_output(0).asnumpy(), runtime


#################################################################################
# A mapping from label to class name, to verify that the outputs from models below
# are reasonable
synset = get_synset()

#################################################################################
# Everyone's favorite cat image for demonstration
inp = get_imagenet_input()

################################################################################
# Deploy a quantized PyTorch Model
# --------------------------------
# First, we demonstrate how to load deep learning models quantized by PyTorch,
# using our PyTorch frontend.
#
# Please refer to the PyTorch static quantization tutorial below to learn about
# their quantization workflow.
# https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html
#
# We use this function to quantize PyTorch models.
# In short, this function takes a floating point model and converts it to uint8,
# with per-channel quantization for weights.

def quantize_model(model, inp):
    model.fuse_model()
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.quantization.prepare(model, inplace=True)
    # Dummy calibration
    model(inp)
    torch.quantization.convert(model, inplace=True)


##############################################################################
# Load quantization-ready, pretrained Mobilenet v2 model from torchvision
# -----------------------------------------------------------------------
# We choose Mobilenet v2 because this model was trained with quantization-aware
# training. Other models require a full post-training calibration.
qmodel = qmobilenet.mobilenet_v2(pretrained=True).eval()

##############################################################################
# Quantize, trace and run the PyTorch Mobilenet v2 model
# ------------------------------------------------------
# The details are out of scope for this tutorial. Please refer to the tutorials
# on the PyTorch website to learn about quantization and jit.
pt_inp = torch.from_numpy(inp)
quantize_model(qmodel, pt_inp)
script_module = torch.jit.trace(qmodel, pt_inp).eval()

with torch.no_grad():
    pt_result = script_module(pt_inp).numpy()

##############################################################################
# Convert quantized Mobilenet v2 to Relay-QNN using the PyTorch frontend
# ----------------------------------------------------------------------
# The PyTorch frontend has support for converting a quantized PyTorch model to
# an equivalent Relay module enriched with quantization-aware operators.
# We call this representation Relay QNN dialect.
#
# You can print the output from the frontend to see how quantized models are
# represented.
#
# You would see operators specific to quantization such as
# qnn.quantize, qnn.dequantize, qnn.requantize, and qnn.conv2d.
input_name = "input"  # the input name can be arbitrary for the PyTorch frontend.
input_shapes = [(input_name, (1, 3, 224, 224))]
mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
# print(mod)  # comment in to see the QNN IR dump

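##############################################################################
# As a quick, illustrative sanity check (not part of the original workflow),
# we can count which QNN operators appear in the imported module by scanning
# its textual IR dump. The exact operator names and counts may vary across
# frontend versions; this is only a rough sketch.
qnn_op_counts = {}
for token in str(mod).split():
    op_name = token.split("(")[0]  # e.g. "qnn.conv2d(%x," -> "qnn.conv2d"
    if op_name.startswith("qnn."):
        qnn_op_counts[op_name] = qnn_op_counts.get(op_name, 0) + 1
print(qnn_op_counts)
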
##############################################################################
# Compile and run the Relay module
# --------------------------------
# Once we have obtained the quantized Relay module, the rest of the workflow
# is the same as running floating point models. Please refer to other
# tutorials for more details.
#
# Under the hood, quantization-specific operators are lowered to a sequence of
# standard Relay operators before compilation.
tvm_result, rt_mod = run_tvm_model(mod, params, input_name, inp, target="llvm")

##########################################################################
# Compare the output labels
# -------------------------
# We should see identical labels printed.
pt_top3_labels = np.argsort(pt_result[0])[::-1][:3]
tvm_top3_labels = np.argsort(tvm_result[0])[::-1][:3]

print("PyTorch top3 labels:", [synset[label] for label in pt_top3_labels])
print("TVM top3 labels:", [synset[label] for label in tvm_top3_labels])

###########################################################################################
# However, due to the difference in numerics, in general the raw floating point
# outputs are not expected to be identical. Here, we print how many floating point
# output values are identical out of the 1000 outputs from Mobilenet v2.
print("%d in 1000 raw floating point outputs identical." % np.sum(tvm_result[0] == pt_result[0]))

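##########################################################################
# A tolerance-based comparison is often more informative than counting exact
# matches. This is a small illustrative check; the tolerance below is an
# arbitrary choice, not a guarantee from either framework.
print("Max absolute difference:", np.max(np.abs(tvm_result[0] - pt_result[0])))
print("Outputs allclose (atol=1e-1):",
      np.allclose(tvm_result[0], pt_result[0], atol=1e-1))
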
##########################################################################
# Measure performance
# -------------------------
# Here we give an example of how to measure performance of TVM compiled models.
n_repeat = 100  # should be bigger to make the measurement more accurate
ctx = tvm.cpu(0)
ftimer = rt_mod.module.time_evaluator("run", ctx, number=1,
                                      repeat=n_repeat)
prof_res = np.array(ftimer().results) * 1e3
print("Elapsed average ms:", np.mean(prof_res))

######################################################################
# .. note::
#
#   We recommend this method for the following reasons:
#
#    * Measurements are done in C++, so there is no Python overhead
#    * It includes several warm-up runs
#    * The same method can be used to profile on remote devices (Android, etc.).


######################################################################
# .. note::
#
#   Unless the hardware has special support for fast 8 bit instructions, quantized models are
#   not expected to be any faster than FP32 models. Without fast 8 bit instructions, TVM does
#   quantized convolution in 16 bit, even if the model itself is 8 bit.
#
#   For x86, the best performance can be achieved on CPUs with the AVX512 instruction set.
#   In this case, TVM utilizes the fastest available 8 bit instructions for the given target.
#   This includes support for the VNNI 8 bit dot product instruction (Cascade Lake or newer).
#
#   Moreover, the following general tips for CPU performance apply equally here:
#
#    * Set the environment variable TVM_NUM_THREADS to the number of physical cores
#    * Choose the best target for your hardware, such as "llvm -mcpu=skylake-avx512" or
#      "llvm -mcpu=cascadelake" (more CPUs with AVX512 support will come in the future);
#      see the sketch below.


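######################################################################
# For illustration only, here is a sketch of how those two tips might look in
# practice on an AVX512-capable machine. The core count and ``-mcpu`` value are
# assumptions about your hardware; adjust them, and only run this on a CPU that
# actually supports the chosen ``-mcpu``:
#
# .. code-block:: python
#
#     import os
#
#     # Set before TVM creates its thread pool, ideally before importing tvm.
#     os.environ["TVM_NUM_THREADS"] = "4"  # assumed number of physical cores
#
#     # Reuse the helper defined above, with a more specific target string.
#     tvm_result_avx512, rt_mod_avx512 = run_tvm_model(
#         mod, params, input_name, inp, target="llvm -mcpu=cascadelake")
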
###############################################################################
# Deploy a quantized MXNet Model
# ------------------------------
# TODO

###############################################################################
# Deploy a quantized TFLite Model
# -------------------------------
# TODO