[quant][pt2] Fix and rename move_model_to_eval (pytorch#108891)
Summary:
This commit fixes two silent correctness problems with
the current implementation of `move_model_to_eval`:

(1) Previously the user had to manually call `eliminate_dead_code`
before calling `move_model_to_eval`, otherwise the dropout pattern
would not actually get eliminated. This is because the subgraph
rewriter considers the match not self-contained and silently skips
the replacement.
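
For reference, this is roughly the manual workaround the old test exercised (a minimal sketch; `m` is assumed to be a module containing dropout and `example_inputs` its example inputs):

    m = capture_pre_autograd_graph(m, example_inputs)
    m.graph.eliminate_dead_code()  # manual step users had to remember
    m.recompile()
    torch.ao.quantization.move_model_to_eval(m)  # old name, see rename below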

(2) We wish to raise an error when the user calls `model.train()` or
`model.eval()` on an exported model. Today this error is raised
correctly immediately after export, but it is no longer raised once
the user calls prepare or convert.

We fix (1) by moving the `eliminate_dead_code` call into
`move_model_to_eval`, and fix (2) by ensuring the respective
errors are thrown after prepare and convert as well.

Additionally, this commit renames `move_model_to_eval` to
`move_exported_model_to_eval` to be more explicit.
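
After this change, a typical flow looks roughly like the following (an illustrative sketch only: the toy module, example inputs, and XNNPACKQuantizer configuration are placeholders loosely mirroring the updated tests, not part of this commit):

    import torch
    import torch.nn.functional as F
    from torch._export import capture_pre_autograd_graph
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    class M(torch.nn.Module):
        # toy module with dropout, loosely mirroring the updated test
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(1, 1)

        def forward(self, x):
            return F.dropout(self.linear(x), p=0.5, training=self.training)

    example_inputs = (torch.randn(1),)
    m = M().train()
    m = capture_pre_autograd_graph(m, example_inputs)

    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config())
    m = prepare_pt2e(m, quantizer)
    m(*example_inputs)  # calibrate

    # (2) train()/eval() now raise after prepare and convert too,
    # not only immediately after export
    try:
        m.eval()
    except NotImplementedError:
        pass

    m = convert_pt2e(m)

    # (1) no manual eliminate_dead_code()/recompile() needed anymore;
    # the dead code elimination now happens inside this call
    torch.ao.quantization.move_exported_model_to_eval(m)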

bypass-github-export-checks

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_disallow_eval_train
python test/test_quantization.py TestQuantizePT2E.test_move_exported_model_to_eval

Imported from OSS

Differential Revision: D49097293

Pull Request resolved: pytorch#108891
Approved by: https://github.com/jerryzh168
andrewor14 authored and pytorchmergebot committed Sep 11, 2023
1 parent 57e5239 commit e8a402c
Showing 6 changed files with 67 additions and 11 deletions.
3 changes: 2 additions & 1 deletion test/inductor/test_mkldnn_pattern_matcher.py
@@ -105,7 +105,8 @@ def _test_common(
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
prepare_model = prepare_pt2e(export_model, quantizer)
prepare_model(*inputs)
convert_model = convert_pt2e(prepare_model).eval()
convert_model = convert_pt2e(prepare_model)
torch.ao.quantization.move_exported_model_to_eval(convert_model)
_ = torch.compile(convert_model)(*inputs)
self.assertEqual(
counters["inductor"]["pattern_matcher_count"], matcher_count
38 changes: 33 additions & 5 deletions test/quantization/pt2e/test_quantize_pt2e.py
@@ -378,7 +378,7 @@ def _verify_symmetric_xnnpack_qat_numerics_helper(
self.assertEqual(after_prepare_result_pt2e, after_prepare_result_fx)

if verify_convert:
torch.ao.quantization.move_model_to_eval(model_pt2e)
torch.ao.quantization.move_exported_model_to_eval(model_pt2e)
model_pt2e = convert_pt2e(model_pt2e)
quant_result_pt2e = model_pt2e(*example_inputs)
model_fx.eval()
@@ -2431,7 +2431,7 @@ def forward(self, x, y):
non_ref_node_occurrence
)

def test_move_model_to_eval(self):
def test_move_exported_model_to_eval(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -2443,8 +2443,6 @@ def forward(self, x):
example_inputs = (torch.randn(1),)
m = M().train()
m = capture_pre_autograd_graph(m, example_inputs)
m.graph.eliminate_dead_code()
m.recompile()

# Assert that dropout op exists and is in train mode
dropout_node = None
@@ -2456,13 +2454,43 @@ def forward(self, x):
self.assertTrue(dropout_node.args[2])

# Do the subgraph rewriting
torch.ao.quantization.move_model_to_eval(m)
torch.ao.quantization.move_exported_model_to_eval(m)

# Assert that dropout op is now replaced with a clone op
targets = [n.target for n in m.graph.nodes]
self.assertTrue(torch.ops.aten.clone.default in targets)
self.assertTrue(torch.ops.aten.native_dropout.default not in targets)

def test_disallow_eval_train(self):
m = TestHelperModules.ConvWithBNRelu(relu=True)
example_inputs = (torch.rand(3, 3, 5, 5),)

# Before export: this is OK
m.eval()
m.train()

# After export: this is not OK
m = capture_pre_autograd_graph(m, example_inputs)
with self.assertRaises(NotImplementedError):
m.eval()
with self.assertRaises(NotImplementedError):
m.train()

# After prepare: still not OK
quantizer = XNNPACKQuantizer()
m = prepare_qat_pt2e(m, quantizer)
with self.assertRaises(NotImplementedError):
m.eval()
with self.assertRaises(NotImplementedError):
m.train()

# After convert: still not OK
m = convert_pt2e(m)
with self.assertRaises(NotImplementedError):
m.eval()
with self.assertRaises(NotImplementedError):
m.train()


@skipIfNoQNNPACK
class TestQuantizePT2EOps(QuantizationTestCase):
4 changes: 2 additions & 2 deletions torch/ao/quantization/__init__.py
@@ -12,7 +12,7 @@
from .quantize import * # noqa: F403
from .quantize_jit import * # noqa: F403
from .stubs import * # noqa: F403
from .pt2e.eval_utils import _move_model_to_eval as move_model_to_eval
from .pt2e.eval_utils import _move_exported_model_to_eval as move_exported_model_to_eval
from typing import Union, List, Callable, Tuple, Optional
from torch import Tensor
import torch
@@ -120,7 +120,7 @@
"get_quantized_operator",
"get_static_quant_module_class",
"load_observer_state_dict",
"move_model_to_eval",
"move_exported_model_to_eval",
"no_observer_set",
"per_channel_weight_observer_range_neg_127_to_127",
"prepare",
8 changes: 6 additions & 2 deletions torch/ao/quantization/pt2e/eval_utils.py
@@ -17,6 +17,10 @@ def _replace_dropout_for_eval(m: torch.fx.GraphModule):
# Avoid circular dependencies
from .utils import get_aten_graph_module

# Needed to ensure subgraph matches are self-contained
m.graph.eliminate_dead_code()
m.recompile()

def dropout_train(x):
return F.dropout(x, p=0.5, training=True)

@@ -39,9 +43,9 @@ def dropout_eval(x):
m.recompile()


# TODO: also support move_model_to_train
# TODO: also support move_exported_model_to_train
# TODO: also support standalone batchnorm
def _move_model_to_eval(model: torch.fx.GraphModule):
def _move_exported_model_to_eval(model: torch.fx.GraphModule):
"""
Move an exported GraphModule to eval mode.
21 changes: 20 additions & 1 deletion torch/ao/quantization/pt2e/utils.py
@@ -1,11 +1,13 @@
import operator
import types

import torch
from torch._export import capture_pre_autograd_graph
from torch.fx import (
GraphModule,
Node,
)
from torch.nn.utils.fusion import fuse_conv_bn_weights
import operator
from typing import Any, Callable, Dict, Optional, Tuple, List, Union
from torch.utils._pytree import LeafSpec

@@ -431,3 +433,20 @@ def replacement(x_i8, scale, zero_point, quant_min, quant_max):
new_args = tuple(new_args)
node.args = new_args
return gm

# TODO: Handle this in export itself and don't wrap the model in another GraphModule
# in prepare and convert
def _disallow_eval_train(model: GraphModule):
"""
Disallow calling `model.train()` or `model.eval()` on the given GraphModule.
This is useful for exported models, where these methods don't actually behave as expected.
"""
def _train(self, mode: bool = True):
raise NotImplementedError("Calling train() is not supported yet.")

def _eval(self, mode: bool = True):
raise NotImplementedError("Calling eval() is not supported yet.")

model.train = types.MethodType(_train, model) # type: ignore[method-assign]
model.eval = types.MethodType(_eval, model) # type: ignore[method-assign]
return model
4 changes: 4 additions & 0 deletions torch/ao/quantization/quantize_pt2e.py
@@ -8,6 +8,7 @@
from .pt2e.utils import (
_get_node_name_to_scope,
_fuse_conv_bn_,
_disallow_eval_train,
)
from .pt2e.representation import reference_representation_rewrite
from .fx.prepare import prepare as fx_prepare
@@ -72,6 +73,7 @@ def prepare_pt2e(
quantizer.validate(model)
model = prepare(model, node_name_to_scope, is_qat=False)
model.meta.update(original_graph_meta)
model = _disallow_eval_train(model)
return model

def prepare_qat_pt2e(
@@ -88,6 +90,7 @@
_fuse_conv_bn_qat(model)
model = prepare(model, node_name_to_scope, is_qat=True)
model.meta.update(original_graph_meta)
model = _disallow_eval_train(model)
return model

def convert_pt2e(
@@ -106,4 +109,5 @@
model = reference_representation_rewrite(model)

model.meta.update(original_graph_meta)
model = _disallow_eval_train(model)
return model
