[FX] Add in resnet + quantization tests (pytorch#43157)
Summary: Pull Request resolved: pytorch#43157

Test Plan: Imported from OSS

Reviewed By: zdevito

Differential Revision: D23173327

Pulled By: jamesr66a

fbshipit-source-id: 724d0f5399d389cdaa53917861b2113c33b9b5f9
James Reed authored and facebook-github-bot committed Aug 18, 2020
1 parent dd194c1 commit 3951457
Showing 2 changed files with 365 additions and 2 deletions.
324 changes: 324 additions & 0 deletions test/fx/quantization.py
@@ -0,0 +1,324 @@
r'''
**This file is EXPERIMENTAL and is mostly used for testing purposes! Do not
rely on it for anything!**
'''
from torch.fx import Graph, GraphModule
from torch.fx.graph import map_arg
from torch.fx.proxy import Proxy
from torch.fx.symbolic_trace import DelegateBase
import sys
import torch
from torch.nn.utils import fuse_conv_bn_weights
import operator

# A quantization pattern key can be a module type, a builtin function, or a
# string to match against a node's target.

def _minmax_scale_zeropoint(min_val, max_val, qmin=-127, qmax=128, eps=torch.finfo(torch.float32).eps):
    min_val = min(0.0, min_val)
    max_val = max(0.0, max_val)
    if max_val == min_val:
        return 1.0, 0
    else:
        scale = (max_val - min_val) / float(qmax - qmin)
        scale = max(scale, eps)
        zero_point = qmin - round(min_val / scale)
        zero_point = max(qmin, zero_point)
        zero_point = min(qmax, zero_point)
        zero_point = int(zero_point)
        return scale, zero_point
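# Illustrative worked example: for an observed range of min=-1.0, max=1.0 with
# qmin=0, qmax=255 (the range MinMaxObserver uses below), this gives
# scale = 2.0 / 255 ≈ 0.00784 and zero_point = 128, roughly the middle of the
# quantized range, as expected for a symmetric input range.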

class MinMaxObserver:
    def __init__(self, quantizer, node):
        self.min, self.max = float('inf'), float('-inf')
        self.all_tensors = True

    def observe(self, node, env):
        v = env[node.name]
        if not isinstance(v, torch.Tensor):
            self.all_tensors = False
            return
        self.max = max(self.max, float(v.max()))
        self.min = min(self.min, float(v.min()))

    def scale_zeropoint(self):
        return _minmax_scale_zeropoint(self.min, self.max, qmin=0, qmax=255)

class NoObserver:
    def __init__(self, quantizer, node):
        pass

    def observe(self, node, env):
        pass

DEFAULT_QUANTIZATION_PATTERNS = {}
def register_pattern(pattern):
    def insert(fn):
        DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn
        return fn
    return insert


@register_pattern(operator.add)
class Add(MinMaxObserver):
    def quantize(self, quantizer, node, load_arg):
        if not self.all_tensors:
            return NotImplemented
        scale, zeropoint = self.scale_zeropoint()
        return quantizer.quantized_graph.create_node(
            'call_function', torch.ops.quantized.add, load_arg(node.args), {'scale': scale, 'zero_point': zeropoint})


class Relu(NoObserver):
    def quantize(self, quantizer, node, load_arg):
        return torch.relu(load_arg(node.args[0]))  # torch.relu works directly on quantized tensors?

# these ops have quantized equivalents that do not need any extra information
@register_pattern(torch.nn.ReLU)
@register_pattern(torch.nn.AvgPool2d)
@register_pattern(torch.nn.MaxPool2d)
@register_pattern(torch.nn.AdaptiveAvgPool2d)
class CopyNode(NoObserver):
    def quantize(self, quantizer, node, load_arg):
        return quantizer.quantized_graph.node_copy(node, load_arg)
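# Illustrative sketch of why a plain copy suffices for these ops: they accept
# quantized tensors directly (assuming a quantized CPU backend is available), e.g.
#   q = torch.quantize_per_tensor(torch.randn(1, 1, 4, 4), 0.1, 0, torch.quint8)
#   out = torch.nn.MaxPool2d(2)(q)  # output stays quantized with the input's scale/zero_point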

class IdentityModule(torch.nn.Module):
    def forward(self, x):
        return x

# handle conv, maybe followed by bn, maybe followed by relu
@register_pattern(torch.nn.modules.conv.Conv2d)
@register_pattern((torch.nn.ReLU, torch.nn.modules.conv.Conv2d))
@register_pattern((torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d))
@register_pattern((torch.nn.ReLU, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d)))
class ConvNormRelu(MinMaxObserver):
    def __init__(self, quantizer, node):
        super().__init__(quantizer, node)
        self.relu_node, self.bn_node = None, None
        if isinstance(quantizer.modules[node.target], torch.nn.ReLU):
            self.relu_node = node
            node = node.args[0]
        if isinstance(quantizer.modules[node.target], torch.nn.BatchNorm2d):
            self.bn_node = node
            self.bn = quantizer.modules[self.bn_node.target]
            node = node.args[0]
        assert isinstance(quantizer.modules[node.target], torch.nn.modules.Conv2d)
        self.conv_node = node
        self.conv = quantizer.modules[self.conv_node.target]

    def quantize(self, quantizer, node, load_arg):
        mod = self.conv
        weight, bias = mod.weight, mod.bias

        if self.bn_node is not None:
            weight, bias = fuse_conv_bn_weights(
                weight, bias, self.bn.running_mean, self.bn.running_var,
                self.bn.eps, self.bn.weight, self.bn.bias)

        min_val, max_val = float(weight.min()), float(weight.max())

        act_scale, act_zp = self.scale_zeropoint()

        weight_scale, weight_zp = _minmax_scale_zeropoint(min_val, max_val)
        qweight = torch.quantize_per_tensor(weight, weight_scale, weight_zp, torch.qint8)

        ctor = torch.nn.intrinsic.quantized.ConvReLU2d if self.relu_node is not None else torch.nn.quantized.Conv2d

        qconv = ctor(mod.in_channels, mod.out_channels, mod.kernel_size,
                     mod.stride, mod.padding, mod.dilation, mod.groups,
                     mod.bias is not None, mod.padding_mode)

        qconv.set_weight_bias(qweight, bias)
        qconv.scale = float(act_scale)
        qconv.zero_point = int(act_zp)
        parent_name, name = _parent_name(self.conv_node.target)
        setattr(quantizer.modules[parent_name], name, qconv)
        if self.bn_node is not None:
            parent_bn, bn_name = _parent_name(self.bn_node.target)
            # We can't simply delete the batch norm submodule because the original
            # module's forward (which is no longer used) still references it, so
            # replace it with something that does nothing.
            setattr(quantizer.modules[parent_bn], bn_name, IdentityModule())

        return quantizer.quantized_graph.create_node('call_module', self.conv_node.target, (load_arg(self.conv_node.args[0]),), {})


# turn 'foo.bar' -> ('foo', 'bar')
def _parent_name(target):
    r = target.rsplit('.', 1)
    if len(r) == 1:
        return '', r[0]
    else:
        return r[0], r[1]
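# e.g. _parent_name('layer1.0.conv1') -> ('layer1.0', 'conv1')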



class DefaultQuant(MinMaxObserver):
    def quantize(self, input):
        assert self.all_tensors
        scale, zeropoint = self.scale_zeropoint()
        return torch.quantize_per_tensor(Proxy(input), scale, zeropoint, torch.quint8).node

def matches(modules, node, pattern, max_uses=sys.maxsize):
    if isinstance(pattern, tuple):
        self_match, *arg_matches = pattern
    else:
        self_match = pattern
        arg_matches = None

    if node.uses > max_uses:
        return False

    if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
        if node.op != 'call_module':
            return False
        if not isinstance(modules[node.target], self_match):
            return False
    elif callable(self_match):
        if node.op != 'call_function' or node.target is not self_match:
            return False
    elif node.target != self_match:
        return False

    if not arg_matches:
        return True

    if len(arg_matches) != len(node.args):
        return False

    return all(matches(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches))
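# For example, the registered pattern (torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d))
# matches a call_module node for a ReLU whose single argument is a BatchNorm2d node
# whose single argument is a Conv2d node, i.e. a relu(bn(conv(x))) chain; the inner
# nodes may have no more than one use (max_uses=1) so the fused module can replace them.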


class Quantizer:
    def __init__(self, mod, patterns=DEFAULT_QUANTIZATION_PATTERNS, quant_ctor=DefaultQuant):
        self.root = mod.root
        self.graph = mod.graph
        self.quant_ctor = quant_ctor

        # cached information for observe
        self.state_dict = self.root.state_dict()
        self.modules = dict(self.root.named_modules())

        # match the patterns that will get quantized
        self.matches = self._find_matches(patterns)
        # find _inputs_ to matched nodes that are not quantized; these
        # have to be quantized, which requires measuring stats, so
        # initialize a quant_ctor object for each of them
        self.quants = self._find_quants(quant_ctor)



    def observe(self, args):
        # most of this function is just an interpreter for the graph
        # it would be possible to put this in some abstraction, but
        # it is pretty nice to just be able to see exactly what is happening here
        # and hack on it.
        # maybe we should just provide an example interpreter that people copy/paste
        # then edit.
        args_iter = iter(args)
        env = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_param':
                result = self.state_dict[node.target]
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'call_method':
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))

            env[node.name] = result
            root_node, obj = self.matches.get(node.name, (None, None))
            if root_node is node:
                obj.observe(node, env)
            if node.name in self.quants:
                self.quants[node.name].observe(node, env)

        return load_arg(self.graph.result)

    def quantize(self):
        self.quantized_graph = Graph()
        self.delegate = DelegateBase(self.quantized_graph)

        env = {}
        quant_env = {}
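        # env maps node names to values of the new graph in float form, while
        # quant_env holds their quantized counterparts; load_arg below converts
        # between the two on demand, either by emitting a dequantize on a
        # quantized value or by asking the node's quant_ctor (DefaultQuant) to
        # emit a quantize_per_tensor from the statistics gathered in observe().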

        def load_arg(n, quantized):
            if not quantized:
                if n.name not in env and n.name in quant_env:
                    env[n.name] = Proxy(quant_env[n.name]).dequantize().node
                return env[n.name]
            else:
                if n.name not in quant_env and n.name in env:
                    quant_env[n.name] = self.quants[n.name].quantize(env[n.name])
                return quant_env[n.name]

        def copy_recursive(node):
            def load_or_emit(n):
                if n.name in env or n.name in quant_env:
                    return load_arg(n, quantized=False)
                else:
                    return copy_recursive(n)
            r = env[node.name] = self.quantized_graph.node_copy(node, lambda n: load_arg(n, quantized=False))
            return r

        for node in self.graph.nodes:
            root_node, obj = self.matches.get(node.name, (None, None))
            if root_node is None:
                # not quantized, just copy it
                env[node.name] = self.quantized_graph.node_copy(node, lambda n: load_arg(n, quantized=False))

            elif root_node is node:
                r = obj.quantize(self, node, lambda a: map_arg(a, lambda n: load_arg(n, quantized=True)))
                if r is NotImplemented:
                    # the quantizer chose not to quantize this node; take the entire match and just copy it over
                    env[node.name] = copy_recursive(node)
                else:
                    quant_env[node.name] = r

        self.quantized_graph.output(load_arg(self.graph.result, quantized=False))
        return GraphModule(self.root, self.quantized_graph)

    def _find_matches(self, patterns):
        modules = dict(self.root.named_modules())
        match_map = {}  # node name -> (root_node, match_value?)

        def apply_match(pattern, node, match):
            if isinstance(pattern, tuple):
                s, *args = pattern
                apply_match(s, node, match)
                for subpattern, arg in zip(args, node.args):
                    apply_match(subpattern, arg, match)
            else:
                match_map[node.name] = match

        for node in reversed(self.graph.nodes):
            if node.name not in match_map:
                for pattern, value in patterns.items():
                    if matches(modules, node, pattern):
                        apply_match(pattern, node, (node, value(self, node)))

        return match_map

    def _find_quants(self, quant_ctor):
        quants = {}

        def visit_arg(n):
            # note: we have to measure quantization information
            # even for nodes where we might not use it because it is already
            # quantized. This is because each match has the option to
            # say NotImplemented (if for instance, it is an __add__ and the data type is not appropriate)
            if n.name not in quants:
                quants[n.name] = quant_ctor(self, n)
        for node in self.graph.nodes:
            if node.name in self.matches:
                map_arg(node.args, visit_arg)
                map_arg(node.kwargs, visit_arg)
        return quants
43 changes: 41 additions & 2 deletions test/test_fx.py
@@ -1,10 +1,19 @@
import torch
import unittest
from torch.fx import symbolic_trace, Proxy, Node, GraphModule, DefaultDelegate

from fx.quantization import Quantizer

from typing import Any, Callable, Dict, Optional, Tuple, Union
from torch.testing._internal.common_utils import TestCase, run_tests

try:
    from torchvision.models import resnet18
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

class TestFX(TestCase):
    def test_graph_module(self):
        class MySub(torch.nn.Module):
@@ -141,12 +150,42 @@ def forward(self, a, b):
                return a + b
        m = M()
        g = symbolic_trace(m).graph
        t = Proxy(g.result)
        # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
        g.output((t + t).node)
        gm = GraphModule(m, g)
        self.assertEqual(gm(3, 4), 14)

    @skipIfNoTorchVision
    def test_resnet(self):
        resnet = resnet18()
        resnet.train()

        res_graph = symbolic_trace(resnet)
        res_script = torch.jit.script(res_graph)

        ip = torch.rand(1, 3, 224, 224)

        a = resnet(ip)
        b = res_graph(ip)
        c = res_script(ip)
        assert torch.allclose(a, b)
        assert torch.allclose(a, c)

        quantizer = Quantizer(res_graph)

        for i in range(10):
            quantizer.observe((torch.rand(1, 3, 224, 224),))

        qgraph = quantizer.quantize()
        qgraph_script = torch.jit.script(qgraph)

        d = qgraph(ip)
        e = qgraph_script(ip)

        assert (a - d).abs().max() < 2
        assert torch.allclose(d, e)


if __name__ == '__main__':
    run_tests()
