forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[FX] Add in resnet + quantization tests (pytorch#43157)
Summary: Pull Request resolved: pytorch#43157 Test Plan: Imported from OSS Reviewed By: zdevito Differential Revision: D23173327 Pulled By: jamesr66a fbshipit-source-id: 724d0f5399d389cdaa53917861b2113c33b9b5f9
- Loading branch information
1 parent
dd194c1
commit 3951457
Showing
2 changed files
with
365 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,324 @@ | ||
r''' | ||
**This file is EXPERIMENTAL and is mostly used for testing purposes! Do not | ||
rely on it for anything!** | ||
''' | ||
from torch.fx import Graph, GraphModule | ||
from torch.fx.graph import map_arg | ||
from torch.fx.proxy import Proxy | ||
from torch.fx.symbolic_trace import DelegateBase | ||
import sys | ||
import torch | ||
from torch.nn.utils import fuse_conv_bn_weights | ||
import operator | ||
|
||
# A quantization pattern key can be a module type, a builtin function, or a
# string to match against a node's target.
|
||
def _minmax_scale_zeropoint(min_val, max_val, qmin=-127, qmax=128, eps=torch.finfo(torch.float32).eps): | ||
min_val = min(0.0, min_val) | ||
max_val = max(0.0, max_val) | ||
if max_val == min_val: | ||
return 1.0, 0 | ||
else: | ||
scale = (max_val - min_val) / float(qmax - qmin) | ||
scale = max(scale, eps) | ||
zero_point = qmin - round(min_val / scale) | ||
zero_point = max(qmin, zero_point) | ||
zero_point = min(qmax, zero_point) | ||
zero_point = int(zero_point) | ||
return scale, zero_point | ||
|
||
class MinMaxObserver:
    """Records the running min/max over every tensor value seen for a node."""

    def __init__(self, quantizer, node):
        # Start with an empty range; the first observation will narrow it.
        self.min = float('inf')
        self.max = float('-inf')
        # Cleared as soon as any observed value is not a Tensor.
        self.all_tensors = True

    def observe(self, node, env):
        value = env[node.name]
        if not isinstance(value, torch.Tensor):
            self.all_tensors = False
            return
        self.min = min(self.min, float(value.min()))
        self.max = max(self.max, float(value.max()))

    def scale_zeropoint(self):
        # quint8 activation range: [0, 255].
        return _minmax_scale_zeropoint(self.min, self.max, qmin=0, qmax=255)
|
||
class NoObserver:
    """Observer stub for patterns that need no runtime statistics."""

    def __init__(self, quantizer, node):
        pass

    def observe(self, node, env):
        pass
|
||
# Global registry: pattern -> handler class (see the pattern-key comment above).
DEFAULT_QUANTIZATION_PATTERNS = {}


def register_pattern(pattern):
    """Class decorator that registers a quantization handler for `pattern`."""
    def insert(fn):
        DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn
        return fn
    return insert
|
||
|
||
@register_pattern(operator.add)
class Add(MinMaxObserver):
    """Handler for `operator.add`: lowers to `torch.ops.quantized.add`."""

    def quantize(self, quantizer, node, load_arg):
        # The quantized add kernel only exists for tensor inputs; for anything
        # else, decline and let the caller copy the match over as float ops.
        if not self.all_tensors:
            return NotImplemented
        scale, zero_point = self.scale_zeropoint()
        qparams = {'scale': scale, 'zero_point': zero_point}
        return quantizer.quantized_graph.create_node(
            'call_function', torch.ops.quantized.add, load_arg(node.args), qparams)
|
||
|
||
class Relu(NoObserver):
    """Handler that re-applies relu to the quantized input."""

    def quantize(self, quantizer, node, load_arg):
        arg = load_arg(node.args[0])
        # NOTE(review): this assumes torch.relu works directly on quantized
        # tensors (the original carried the same question) — confirm.
        return torch.relu(arg)
|
||
# These ops have quantized equivalents that need no extra information:
# they behave identically on quantized tensors, so the node is copied
# into the quantized graph unchanged.
@register_pattern(torch.nn.ReLU)
@register_pattern(torch.nn.AvgPool2d)
@register_pattern(torch.nn.MaxPool2d)
@register_pattern(torch.nn.AdaptiveAvgPool2d)
class CopyNode(NoObserver):
    """Handler that copies the matched node verbatim into the new graph."""

    def quantize(self, quantizer, node, load_arg):
        return quantizer.quantized_graph.node_copy(node, load_arg)
|
||
class IdentityModule(torch.nn.Module):
    """No-op module used to replace submodules fused away (e.g. BatchNorm)."""

    def forward(self, x):
        return x
|
||
# handle conv, maybe followed by bn, maybe followed by relu
@register_pattern(torch.nn.modules.conv.Conv2d)
@register_pattern((torch.nn.ReLU, torch.nn.modules.conv.Conv2d))
@register_pattern((torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d))
@register_pattern((torch.nn.ReLU, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d)))
class ConvNormRelu(MinMaxObserver):
    """Handler for Conv2d, optionally fused with BatchNorm2d and/or ReLU.

    At match time it walks outside-in from the matched root node (ReLU, then
    BatchNorm, then Conv2d), remembering each participating node/module. At
    quantize time it folds the BatchNorm statistics into the conv weights,
    quantizes them, swaps the float conv for a quantized (ConvReLU2d or
    Conv2d) module in place, and neutralizes the fused BatchNorm.
    """

    def __init__(self, quantizer, node):
        super().__init__(quantizer, node)
        self.relu_node, self.bn_node = None, None
        # The matched root may be the ReLU; peel layers until the Conv2d.
        if isinstance(quantizer.modules[node.target], torch.nn.ReLU):
            self.relu_node = node
            node = node.args[0]
        if isinstance(quantizer.modules[node.target], torch.nn.BatchNorm2d):
            self.bn_node = node
            self.bn = quantizer.modules[self.bn_node.target]
            node = node.args[0]
        assert isinstance(quantizer.modules[node.target], torch.nn.modules.Conv2d)
        self.conv_node = node
        self.conv = quantizer.modules[self.conv_node.target]

    def quantize(self, quantizer, node, load_arg):
        mod = self.conv
        weight, bias = mod.weight, mod.bias

        if self.bn_node is not None:
            # Fold the BatchNorm running statistics into the conv parameters.
            weight, bias = fuse_conv_bn_weights(
                weight, bias, self.bn.running_mean, self.bn.running_var,
                self.bn.eps, self.bn.weight, self.bn.bias)

        min_val, max_val = float(weight.min()), float(weight.max())

        # Activation qparams come from the observed min/max (quint8 range).
        act_scale, act_zp = self.scale_zeropoint()

        # Weight qparams use the signed qint8 defaults of the helper.
        weight_scale, weight_zp = _minmax_scale_zeropoint(min_val, max_val)
        qweight = torch.quantize_per_tensor(weight, weight_scale, weight_zp, torch.qint8)

        ctor = torch.nn.intrinsic.quantized.ConvReLU2d if self.relu_node is not None else torch.nn.quantized.Conv2d

        qconv = ctor(mod.in_channels, mod.out_channels, mod.kernel_size,
                     mod.stride, mod.padding, mod.dilation, mod.groups,
                     mod.bias is not None, mod.padding_mode)

        qconv.set_weight_bias(qweight, bias)
        qconv.scale = float(act_scale)
        qconv.zero_point = int(act_zp)
        parent_name, name = _parent_name(self.conv_node.target)
        setattr(quantizer.modules[parent_name], name, qconv)
        if self.bn_node is not None:
            parent_bn, bn_name = _parent_name(self.bn_node.target)
            # We can't just delete the BatchNorm because submodule forwards
            # (which are no longer used) still reference it, so replace it
            # with a module that does nothing.
            # BUG FIX: replace it under *its own* parent (`parent_bn`); the
            # original used `parent_name` (the conv's parent), which is wrong
            # whenever the conv and bn live in different submodules.
            setattr(quantizer.modules[parent_bn], bn_name, IdentityModule())

        return quantizer.quantized_graph.create_node('call_module', self.conv_node.target, (load_arg(self.conv_node.args[0]),), {})
|
||
|
||
# turn foo.bar -> ['foo', 'bar'] | ||
def _parent_name(target): | ||
r = target.rsplit('.', 1) | ||
if len(r) == 1: | ||
return '', r[0] | ||
else: | ||
return r[0], r[1] | ||
|
||
|
||
|
||
class DefaultQuant(MinMaxObserver):
    """Default quant_ctor: emits a quantize_per_tensor for a float input node."""

    def quantize(self, input):
        assert self.all_tensors
        scale, zero_point = self.scale_zeropoint()
        quantized = torch.quantize_per_tensor(Proxy(input), scale, zero_point, torch.quint8)
        return quantized.node
|
||
def matches(modules, node, pattern, max_uses=sys.maxsize):
    """Return whether `node` (and possibly its args) matches `pattern`.

    `pattern` is either a single matcher — an nn.Module subclass (matched
    against call_module targets), a callable (matched against call_function
    targets), or a raw value compared against the node's target — or a tuple
    `(self_matcher, *arg_patterns)` whose tail is matched against the node's
    arguments. Argument nodes must have a single use so a fused subgraph is
    never shared with an outside consumer.
    """
    if isinstance(pattern, tuple):
        self_match, *arg_matches = pattern
    else:
        self_match, arg_matches = pattern, None

    # A node consumed more than `max_uses` times cannot be absorbed.
    if node.uses > max_uses:
        return False

    if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
        if node.op != 'call_module' or not isinstance(modules[node.target], self_match):
            return False
    elif callable(self_match):
        if node.op != 'call_function' or node.target is not self_match:
            return False
    else:
        if node.target != self_match:
            return False

    if not arg_matches:
        return True

    if len(arg_matches) != len(node.args):
        return False

    return all(
        matches(modules, arg_node, arg_pattern, max_uses=1)
        for arg_node, arg_pattern in zip(node.args, arg_matches)
    )
|
||
|
||
class Quantizer:
    """Post-training quantizer for a symbolically traced module.

    Workflow:
      1. ``Quantizer(traced_module)`` — matches the quantization patterns
         against the graph and sets up observers.
      2. ``observe(args)`` — runs the float graph on calibration inputs so the
         observers can record statistics (call once per calibration batch).
      3. ``quantize()`` — emits a new ``GraphModule`` in which matched
         patterns are replaced with their quantized equivalents.
    """

    def __init__(self, mod, patterns=DEFAULT_QUANTIZATION_PATTERNS, quant_ctor=DefaultQuant):
        # `mod` exposes `root` (the original nn.Module) and `graph` (its IR).
        self.root = mod.root
        self.graph = mod.graph
        self.quant_ctor = quant_ctor

        # cached information for observe
        self.state_dict = self.root.state_dict()
        self.modules = dict(self.root.named_modules())

        # match the patterns that will get quantized
        self.matches = self._find_matches(patterns)
        # find _inputs_ to matched nodes that are not quantized; these have
        # to be quantized, which requires measuring stats, so initialize a
        # quant_ctor object for each
        self.quants = self._find_quants(quant_ctor)

    def observe(self, args):
        """Interpret the float graph on `args`, feeding every intermediate
        value to its observers; returns the graph's result value."""
        # most of this function is just an interpreter for the graph
        # it would be possible to put this in some abstraction, but
        # it is pretty nice to just be able to see exactly what is happening here
        # and hack on it.
        # maybe we should just provide an example interpreter that people copy/paste
        # then edit.
        args_iter = iter(args)
        env = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_param':
                result = self.state_dict[node.target]
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'call_method':
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
            # NOTE(review): any other op kind would silently reuse the
            # previous `result`; this assumes tracing only produces the op
            # kinds handled above.

            env[node.name] = result
            root_node, obj = self.matches.get(node.name, (None, None))
            if root_node is node:
                # Only the root node of a match drives its observer.
                obj.observe(node, env)
            if node.name in self.quants:
                self.quants[node.name].observe(node, env)

        return load_arg(self.graph.result)

    def quantize(self):
        """Emit and return a new GraphModule with matched patterns quantized."""
        self.quantized_graph = Graph()
        self.delegate = DelegateBase(self.quantized_graph)

        # `env` maps node names to their float value in the new graph;
        # `quant_env` maps them to their quantized value. Conversions are
        # inserted lazily by `load_arg`.
        env = {}
        quant_env = {}

        def load_arg(n, quantized):
            # Fetch node `n` from the requested environment, inserting a
            # dequantize / quantize conversion on demand.
            if not quantized:
                if n.name not in env and n.name in quant_env:
                    env[n.name] = Proxy(quant_env[n.name]).dequantize().node
                return env[n.name]
            else:
                if n.name not in quant_env and n.name in env:
                    quant_env[n.name] = self.quants[n.name].quantize(env[n.name])
                return quant_env[n.name]

        def copy_recursive(node):
            # Copy `node` into the new graph as a float op. Args interior to
            # a declined match were never emitted, so they are copied
            # recursively.
            def load_or_emit(n):
                # BUG FIX: the original had `e.name` (NameError) and a
                # misspelled recursive call (`copy_recusive`), and never
                # passed this helper to node_copy — so interior args of a
                # declined match would raise KeyError instead of being
                # emitted.
                if n.name in env or n.name in quant_env:
                    return load_arg(n, quantized=False)
                else:
                    return copy_recursive(n)
            r = env[node.name] = self.quantized_graph.node_copy(node, load_or_emit)
            return r

        for node in self.graph.nodes:
            root_node, obj = self.matches.get(node.name, (None, None))
            if root_node is None:
                # not part of any match: just copy it over
                env[node.name] = self.quantized_graph.node_copy(node, lambda n: load_arg(n, quantized=False))
            elif root_node is node:
                r = obj.quantize(self, node, lambda a: map_arg(a, lambda n: load_arg(n, quantized=True)))
                if r is NotImplemented:
                    # the handler declined to quantize this node: take the
                    # entire match and just copy it over as float ops
                    env[node.name] = copy_recursive(node)
                else:
                    quant_env[node.name] = r
            # Interior nodes of a match (root_node is not None but not this
            # node) are intentionally skipped: the root emits the fused op.

        self.quantized_graph.output(load_arg(self.graph.result, quantized=False))
        return GraphModule(self.root, self.quantized_graph)

    def _find_matches(self, patterns):
        """Map every node participating in a pattern match to
        (root_node, handler_instance)."""
        modules = dict(self.root.named_modules())
        match_map = {}  # node name -> (root_node, match_value?)

        def apply_match(pattern, node, match):
            # Record `match` for `node` and, for tuple patterns, for every
            # argument node the pattern covers.
            if isinstance(pattern, tuple):
                s, *args = pattern
                apply_match(s, node, match)
                for subpattern, arg in zip(args, node.args):
                    apply_match(subpattern, arg, match)
            else:
                match_map[node.name] = match

        # Walk in reverse so the outermost node of a pattern is matched
        # first; its interior nodes are then already claimed.
        for node in reversed(self.graph.nodes):
            if node.name not in match_map:
                for pattern, value in patterns.items():
                    if matches(modules, node, pattern):
                        apply_match(pattern, node, (node, value(self, node)))

        return match_map

    def _find_quants(self, quant_ctor):
        """Create a quant_ctor observer for every input of every matched
        node, keyed by the input node's name."""
        quants = {}

        def visit_arg(n):
            # note: we have to measure quantization information
            # even for nodes where we might not use it because it is already
            # quantized. This is because each match has the option to
            # say NotImplemented (if for instance, it is an __add__ and the data type is not appropriate)
            if n.name not in quants:
                quants[n.name] = quant_ctor(self, n)

        for node in self.graph.nodes:
            if node.name in self.matches:
                map_arg(node.args, visit_arg)
                map_arg(node.kwargs, visit_arg)
        return quants
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters