From 3951457ca58f9f22a39c6e28a716c8a274361a86 Mon Sep 17 00:00:00 2001
From: James Reed
Date: Tue, 18 Aug 2020 14:59:20 -0700
Subject: [PATCH] [FX] Add in resnet + quantization tests (#43157)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43157

Test Plan: Imported from OSS

Reviewed By: zdevito

Differential Revision: D23173327

Pulled By: jamesr66a

fbshipit-source-id: 724d0f5399d389cdaa53917861b2113c33b9b5f9
---
 test/fx/quantization.py | 324 ++++++++++++++++++++++++++++++++++++++++
 test/test_fx.py         |  43 +++++-
 2 files changed, 365 insertions(+), 2 deletions(-)
 create mode 100644 test/fx/quantization.py

diff --git a/test/fx/quantization.py b/test/fx/quantization.py
new file mode 100644
index 00000000000000..c32bac4017b70d
--- /dev/null
+++ b/test/fx/quantization.py
@@ -0,0 +1,324 @@
+r'''
+**This file is EXPERIMENTAL and is mostly used for testing purposes! Do not
+rely on it for anything!**
+'''
+from torch.fx import Graph, GraphModule
+from torch.fx.graph import map_arg
+from torch.fx.proxy import Proxy
+from torch.fx.symbolic_trace import DelegateBase
+import sys
+import torch
+from torch.nn.utils import fuse_conv_bn_weights
+import operator
+
+# a pattern used with register_pattern below can be a module type, a
+# builtin function, or a string to match against a node's target
+
+def _minmax_scale_zeropoint(min_val, max_val, qmin=-127, qmax=128, eps=torch.finfo(torch.float32).eps):
+    min_val = min(0.0, min_val)
+    max_val = max(0.0, max_val)
+    if max_val == min_val:
+        return 1.0, 0
+    else:
+        scale = (max_val - min_val) / float(qmax - qmin)
+        scale = max(scale, eps)
+        zero_point = qmin - round(min_val / scale)
+        zero_point = max(qmin, zero_point)
+        zero_point = min(qmax, zero_point)
+        zero_point = int(zero_point)
+        return scale, zero_point
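
# Worked example of the arithmetic above (hypothetical numbers): for an
# observed range [-1.0, 2.0] with the quint8 parameters qmin=0, qmax=255
# that MinMaxObserver below passes in,
#
#     scale, zp = _minmax_scale_zeropoint(-1.0, 2.0, qmin=0, qmax=255)
#     # scale == 3.0 / 255 (about 0.0118), zp == 0 - round(-1.0 / scale) == 85
#
# so the real value 0.0 maps exactly onto the integer 85, and a quantized
# value q dequantizes back to (q - 85) * scale.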
+
+class MinMaxObserver:
+    def __init__(self, quantizer, node):
+        self.min, self.max = float('inf'), float('-inf')
+        self.all_tensors = True
+
+    def observe(self, node, env):
+        v = env[node.name]
+        if not isinstance(v, torch.Tensor):
+            self.all_tensors = False
+            return
+        self.max = max(self.max, float(v.max()))
+        self.min = min(self.min, float(v.min()))
+
+    def scale_zeropoint(self):
+        return _minmax_scale_zeropoint(self.min, self.max, qmin=0, qmax=255)
+
+class NoObserver:
+    def __init__(self, quantizer, node):
+        pass
+
+    def observe(self, node, env):
+        pass
+
+DEFAULT_QUANTIZATION_PATTERNS = {}
+def register_pattern(pattern):
+    def insert(fn):
+        DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn
+        return fn
+    return insert
+
+
+@register_pattern(operator.add)
+class Add(MinMaxObserver):
+    def quantize(self, quantizer, node, load_arg):
+        if not self.all_tensors:
+            return NotImplemented
+        scale, zeropoint = self.scale_zeropoint()
+        return quantizer.quantized_graph.create_node(
+            'call_function', torch.ops.quantized.add, load_arg(node.args), {'scale': scale, 'zero_point': zeropoint})
+
+
+class Relu(NoObserver):
+    def quantize(self, quantizer, node, load_arg):
+        return torch.relu(load_arg(node.args[0]))  # torch.relu works directly on quantized tensors?
+
+# these ops have quantized equivalents that do not need any extra information
+@register_pattern(torch.nn.ReLU)
+@register_pattern(torch.nn.AvgPool2d)
+@register_pattern(torch.nn.MaxPool2d)
+@register_pattern(torch.nn.AdaptiveAvgPool2d)
+class CopyNode(NoObserver):
+    def quantize(self, quantizer, node, load_arg):
+        return quantizer.quantized_graph.node_copy(node, load_arg)
+
+class IdentityModule(torch.nn.Module):
+    def forward(self, x):
+        return x
+
+# handle conv, maybe followed by bn, maybe followed by relu
+@register_pattern(torch.nn.modules.conv.Conv2d)
+@register_pattern((torch.nn.ReLU, torch.nn.modules.conv.Conv2d))
+@register_pattern((torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d))
+@register_pattern((torch.nn.ReLU, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d)))
+class ConvNormRelu(MinMaxObserver):
+    def __init__(self, quantizer, node):
+        super().__init__(quantizer, node)
+        self.relu_node, self.bn_node = None, None
+        if isinstance(quantizer.modules[node.target], torch.nn.ReLU):
+            self.relu_node = node
+            node = node.args[0]
+        if isinstance(quantizer.modules[node.target], torch.nn.BatchNorm2d):
+            self.bn_node = node
+            self.bn = quantizer.modules[self.bn_node.target]
+            node = node.args[0]
+        assert isinstance(quantizer.modules[node.target], torch.nn.modules.Conv2d)
+        self.conv_node = node
+        self.conv = quantizer.modules[self.conv_node.target]
+
+    def quantize(self, quantizer, node, load_arg):
+        mod = self.conv
+        weight, bias = mod.weight, mod.bias
+
+        if self.bn_node is not None:
+            weight, bias = fuse_conv_bn_weights(
+                weight, bias, self.bn.running_mean, self.bn.running_var,
+                self.bn.eps, self.bn.weight, self.bn.bias)
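
        # For reference, fuse_conv_bn_weights implements the standard
        # conv+bn fold: with bn affine parameters gamma, beta and running
        # statistics mu, var, it returns (per output channel)
        #     w_fused = w * gamma / sqrt(var + eps)
        #     b_fused = (b - mu) * gamma / sqrt(var + eps) + beta
        # so the fused conv alone reproduces conv followed by batch norm
        # in eval mode.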
+
+        min_val, max_val = float(weight.min()), float(weight.max())
+
+        act_scale, act_zp = self.scale_zeropoint()
+
+        weight_scale, weight_zp = _minmax_scale_zeropoint(min_val, max_val)
+        qweight = torch.quantize_per_tensor(weight, weight_scale, weight_zp, torch.qint8)
+
+        ctor = torch.nn.intrinsic.quantized.ConvReLU2d if self.relu_node is not None else torch.nn.quantized.Conv2d
+
+        qconv = ctor(mod.in_channels, mod.out_channels, mod.kernel_size,
+                     mod.stride, mod.padding, mod.dilation, mod.groups,
+                     mod.bias is not None, mod.padding_mode)
+
+        qconv.set_weight_bias(qweight, bias)
+        qconv.scale = float(act_scale)
+        qconv.zero_point = int(act_zp)
+        parent_name, name = _parent_name(self.conv_node.target)
+        setattr(quantizer.modules[parent_name], name, qconv)
+        if self.bn_node is not None:
+            parent_bn, bn_name = _parent_name(self.bn_node.target)
+            # we can't just delete the batch norm, because submodules' forwards
+            # (which are no longer used) still try to call it, so replace it
+            # with something that does nothing.
+            setattr(quantizer.modules[parent_bn], bn_name, IdentityModule())
+
+        return quantizer.quantized_graph.create_node(
+            'call_module', self.conv_node.target, (load_arg(self.conv_node.args[0]),), {})
+
+
+# turn 'foo.bar' into ('foo', 'bar')
+def _parent_name(target):
+    r = target.rsplit('.', 1)
+    if len(r) == 1:
+        return '', r[0]
+    else:
+        return r[0], r[1]
+
+
+class DefaultQuant(MinMaxObserver):
+    def quantize(self, input):
+        assert self.all_tensors
+        scale, zeropoint = self.scale_zeropoint()
+        return torch.quantize_per_tensor(Proxy(input), scale, zeropoint, torch.quint8).node
+
+def matches(modules, node, pattern, max_uses=sys.maxsize):
+    if isinstance(pattern, tuple):
+        self_match, *arg_matches = pattern
+    else:
+        self_match = pattern
+        arg_matches = None
+
+    if node.uses > max_uses:
+        return False
+
+    if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
+        if node.op != 'call_module':
+            return False
+        if not isinstance(modules[node.target], self_match):
+            return False
+    elif callable(self_match):
+        if node.op != 'call_function' or node.target is not self_match:
+            return False
+    elif node.target != self_match:
+        return False
+
+    if not arg_matches:
+        return True
+
+    if len(arg_matches) != len(node.args):
+        return False
+
+    return all(matches(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches))
+
+
+class Quantizer:
+    def __init__(self, mod, patterns=DEFAULT_QUANTIZATION_PATTERNS, quant_ctor=DefaultQuant):
+        self.root = mod.root
+        self.graph = mod.graph
+        self.quant_ctor = quant_ctor
+
+        # cached information for observe
+        self.state_dict = self.root.state_dict()
+        self.modules = dict(self.root.named_modules())
+
+        # match the patterns that will get quantized
+        self.matches = self._find_matches(patterns)
+        # find _inputs_ to matched nodes that are not quantized; these
+        # have to be quantized, which requires measuring stats, so
+        # initialize a quant_ctor object for each one
+        self.quants = self._find_quants(quant_ctor)
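
    # For orientation, a hypothetical example of what these two tables hold:
    # if the traced graph contains relu_1(bn_1(conv_1(x))) and it is matched
    # by the pattern (torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)),
    # then self.matches maps each of 'relu_1', 'bn_1', and 'conv_1' to the
    # same pair (relu_1 node, ConvNormRelu instance) rooted at the relu,
    # while self.quants gets a DefaultQuant observer for every argument of a
    # matched node (e.g. 'x'), in case a quantize_per_tensor has to be
    # inserted there at conversion time.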
+
+    def observe(self, args):
+        # most of this function is just an interpreter for the graph.
+        # it would be possible to put this in some abstraction, but
+        # it is pretty nice to be able to see exactly what is happening here
+        # and hack on it. maybe we should just provide an example interpreter
+        # that people copy/paste and then edit.
+        args_iter = iter(args)
+        env = {}
+
+        def load_arg(a):
+            return map_arg(a, lambda node: env[node.name])
+
+        for node in self.graph.nodes:
+            if node.op == 'placeholder':
+                result = next(args_iter)
+            elif node.op == 'get_param':
+                result = self.state_dict[node.target]
+            elif node.op == 'call_function':
+                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
+            elif node.op == 'call_method':
+                self_obj, *args = load_arg(node.args)
+                kwargs = load_arg(node.kwargs)
+                result = getattr(self_obj, node.target)(*args, **kwargs)
+            elif node.op == 'call_module':
+                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
+
+            env[node.name] = result
+            root_node, obj = self.matches.get(node.name, (None, None))
+            if root_node is node:
+                obj.observe(node, env)
+            if node.name in self.quants:
+                self.quants[node.name].observe(node, env)
+
+        return load_arg(self.graph.result)
+
+    def quantize(self):
+        self.quantized_graph = Graph()
+        self.delegate = DelegateBase(self.quantized_graph)
+
+        env = {}
+        quant_env = {}
+
+        def load_arg(n, quantized):
+            if not quantized:
+                if n.name not in env and n.name in quant_env:
+                    env[n.name] = Proxy(quant_env[n.name]).dequantize().node
+                return env[n.name]
+            else:
+                if n.name not in quant_env and n.name in env:
+                    quant_env[n.name] = self.quants[n.name].quantize(env[n.name])
+                return quant_env[n.name]
+
+        def copy_recursive(node):
+            def load_or_emit(n):
+                if n.name in env or n.name in quant_env:
+                    return load_arg(n, quantized=False)
+                else:
+                    return copy_recursive(n)
+            r = env[node.name] = self.quantized_graph.node_copy(node, load_or_emit)
+            return r
+
+        for node in self.graph.nodes:
+            root_node, obj = self.matches.get(node.name, (None, None))
+            if root_node is None:
+                # not part of a match; just copy it over
+                env[node.name] = self.quantized_graph.node_copy(node, lambda n: load_arg(n, quantized=False))
+            elif root_node is node:
+                r = obj.quantize(self, node, lambda a: map_arg(a, lambda n: load_arg(n, quantized=True)))
+                if r is NotImplemented:
+                    # the quantizer chose not to quantize the node; take the
+                    # entire match and just copy it over
+                    env[node.name] = copy_recursive(node)
+                else:
+                    quant_env[node.name] = r
+
+        self.quantized_graph.output(load_arg(self.graph.result, quantized=False))
+        return GraphModule(self.root, self.quantized_graph)
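
    # A note on the two environments in quantize() above: env maps original
    # node names to float-valued nodes in the new graph, quant_env maps them
    # to quantized-valued nodes, and load_arg converts lazily between the
    # two. A dequantize (or a quantize_per_tensor, via the DefaultQuant
    # observers in self.quants) is therefore only emitted at a
    # float/quantized boundary that is actually crossed.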
+
+    def _find_matches(self, patterns):
+        modules = dict(self.root.named_modules())
+        match_map = {}  # node name -> (root_node, match_value)
+
+        def apply_match(pattern, node, match):
+            if isinstance(pattern, tuple):
+                s, *args = pattern
+                apply_match(s, node, match)
+                for subpattern, arg in zip(args, node.args):
+                    apply_match(subpattern, arg, match)
+            else:
+                match_map[node.name] = match
+
+        for node in reversed(self.graph.nodes):
+            if node.name not in match_map:
+                for pattern, value in patterns.items():
+                    if matches(modules, node, pattern):
+                        apply_match(pattern, node, (node, value(self, node)))
+
+        return match_map
+
+    def _find_quants(self, quant_ctor):
+        quants = {}
+
+        def visit_arg(n):
+            # note: we have to measure quantization information even for
+            # nodes where we might not use it because it is already
+            # quantized. This is because each match has the option to say
+            # NotImplemented (if, for instance, it is an __add__ and the
+            # data type is not appropriate)
+            if n.name not in quants:
+                quants[n.name] = quant_ctor(self, n)
+        for node in self.graph.nodes:
+            if node.name in self.matches:
+                map_arg(node.args, visit_arg)
+                map_arg(node.kwargs, visit_arg)
+        return quants
diff --git a/test/test_fx.py b/test/test_fx.py
index d31e42aad09df0..e3f7c3eebe55f0 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -1,10 +1,19 @@
 import torch
+import unittest
 from torch.fx import symbolic_trace, Proxy, Node, GraphModule, DefaultDelegate
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from fx.quantization import Quantizer
+from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 from torch.testing._internal.common_utils import TestCase, run_tests
 
+try:
+    from torchvision.models import resnet18
+    HAS_TORCHVISION = True
+except ImportError:
+    HAS_TORCHVISION = False
+skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
+
 class TestFX(TestCase):
     def test_graph_module(self):
         class MySub(torch.nn.Module):
@@ -141,12 +150,42 @@ def forward(self, a, b):
                 return a + b
         m = M()
         g = symbolic_trace(m).graph
-        t = Proxy(g.result)
+        t = Proxy(g.result)  # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
         g.output((t + t).node)
         gm = GraphModule(m, g)
         self.assertEqual(gm(3, 4), 14)
 
+    @skipIfNoTorchVision
+    def test_resnet(self):
+        resnet = resnet18()
+        resnet.train()
+
+        res_graph = symbolic_trace(resnet)
+        res_script = torch.jit.script(res_graph)
+
+        ip = torch.rand(1, 3, 224, 224)
+
+        a = resnet(ip)
+        b = res_graph(ip)
+        c = res_script(ip)
+        assert torch.allclose(a, b)
+        assert torch.allclose(a, c)
+
+        quantizer = Quantizer(res_graph)
+
+        for i in range(10):
+            quantizer.observe((torch.rand(1, 3, 224, 224),))
+
+        qgraph = quantizer.quantize()
+        qgraph_script = torch.jit.script(qgraph)
+
+        d = qgraph(ip)
+        e = qgraph_script(ip)
+
+        assert (a - d).abs().max() < 2
+        assert torch.allclose(d, e)
+
 if __name__ == '__main__':
     run_tests()
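
Taken together, test_resnet exercises the intended calibrate-then-convert
workflow. Outside the test suite the same flow looks roughly like the sketch
below (assuming torchvision is installed; the fx.quantization import path
only resolves relative to the test/ directory, and the module is
experimental, per its own header):

    import torch
    from torch.fx import symbolic_trace
    from torchvision.models import resnet18
    from fx.quantization import Quantizer

    model = symbolic_trace(resnet18().train())  # trace to a GraphModule first
    quantizer = Quantizer(model)

    # calibration: run representative batches through Quantizer.observe so
    # each MinMaxObserver records realistic activation ranges
    for _ in range(10):
        quantizer.observe((torch.rand(1, 3, 224, 224),))

    # conversion: rewrite the graph, swapping matched conv/bn/relu chains
    # for quantized modules and inserting quantize/dequantize at boundaries
    qmodel = quantizer.quantize()
    out = qmodel(torch.rand(1, 3, 224, 224))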