dbr quant: add torchscript pass to remove redundant aliases (pytorch#71230)

Summary:
Pull Request resolved: pytorch#71230

DBR quantization uses `torch.Tensor.as_subclass` frequently. When
the quantized model is traced with `torch.jit.trace`, these calls appear
in the resulting graph as `aten::alias` (a minimal repro is sketched
after the list below). This PR adds a pass to remove these calls from
the graph, for two reasons:
1. ease of debugging (these calls do nothing)
2. less work for downstream passes (for example, converting to ONNX currently breaks if these alias calls are present)
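
For reference, a minimal standalone repro of the `as_subclass` to `aten::alias` behavior (not part of this PR's diff; the printed graph is expected to contain an `aten::alias` node):
```
import torch

def f(x):
    # DBR quant wraps/unwraps tensors with as_subclass; under
    # torch.jit.trace this call is recorded as aten::alias
    return x.as_subclass(torch.Tensor) + 1

traced = torch.jit.trace(f, (torch.randn(2, 2),))
print(traced.graph)  # expect an aten::alias node feeding the add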

For now, we have to inline the graph in order for `AliasDb` to determine
safety properly. In the future, we may choose to relax this if there is
a need for it.
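
To illustrate why the inlined view matters (a sketch, reusing the `mqs` traced module from the test below):
```
# aten::alias nodes are spread across the forward methods of the module
# hierarchy; the top-level graph hides submodule bodies behind
# prim::CallMethod, while the inlined view exposes every node in one flat
# graph, which is what the alias analysis needs to reason about safety
print(mqs.graph)          # submodule calls appear as prim::CallMethod
print(mqs.inlined_graph)  # single flat graph, all aten::alias nodes visible
```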

Test Plan:
The test plan is basic for now; it can be improved in future PRs.
```
python test/test_quantization.py TestQuantizeDBR.test_jit_tracing_removes_aliases
```

Reviewed By: eellison

Differential Revision: D33552387

Pulled By: vkuzo

fbshipit-source-id: 681a33ddfff394a91e971263ac593afd93c5ea78
(cherry picked from commit 0f84127)
vkuzo authored and pytorchmergebot committed Mar 3, 2022
1 parent eb8d065 commit bf896a2
Showing 6 changed files with 161 additions and 0 deletions.
46 changes: 46 additions & 0 deletions test/quantization/dbr/test_quantize_dbr.py
@@ -33,6 +33,9 @@
import torch.ao.ns._numeric_suite_dbr as ns
# TODO(future PR): move these utils out of the FX folder
import torch.ao.ns._numeric_suite_fx as ns_fx
from torch.ao.quantization._dbr.torchscript_utils import (
    remove_redundant_aliases,
)

def _allclose(a, b):
    if isinstance(a, tuple):
@@ -1303,6 +1306,34 @@ def forward(self, x):
        input_shape = (1, 1, 1, 1)
        self._test_serialization(M, input_shape)

    def test_jit_tracing_removes_aliases(self):
        m = nn.Sequential(
            nn.Conv2d(1, 1, 1),
            nn.Sequential(
                nn.Conv2d(1, 1, 1),
            ),
        )
        qconfig_dict = {'': torch.quantization.default_qconfig}
        example_args = (torch.randn(1, 1, 1, 1),)
        mp = _quantize_dbr.prepare(m, qconfig_dict, example_args)
        mq = _quantize_dbr.convert(mp)
        mqs = torch.jit.trace(mq, example_args)
        FileCheck().check_count("aten::alias", 5, exactly=True).run(
            mqs.inlined_graph)
        res1 = mqs(*example_args)
        mqs = remove_redundant_aliases(mqs)
        res2 = mqs(*example_args)
        self.assertTrue(torch.allclose(res1, res2))
        # TODO(future PR): figure out why aliasing still appears in the inlined
        # graph, and if that is fixed then just check the inlined graph.
        for graph in (
            mqs.graph,
            getattr(mqs, '1').graph,
            getattr(getattr(mqs, '1'), '0').graph,
        ):
            FileCheck().check_count("aten::alias", 0, exactly=True).run(graph)


@skipIfNoFBGEMM
class TestQuantizeDBRMultipleOps(QuantizeDBRTestCase):
"""
@@ -1543,3 +1574,18 @@ def test_mobilenet_v2(self):
            m, qconfig, (torch.randn(1, 3, 224, 224),),
            # TODO fix this (reason TBD)
            do_torchscript_checks=False)

    @skip_if_no_torchvision
    def test_mobilenet_v2_removes_aliases(self):
        import torchvision
        m = torchvision.models.__dict__['mobilenet_v2'](pretrained=False)\
            .eval().float()
        qconfig_dict = {'': torch.quantization.default_qconfig}
        example_args = (torch.randn(1, 3, 224, 224),)
        mp = _quantize_dbr.prepare(m, qconfig_dict, example_args)
        mq = _quantize_dbr.convert(mp)
        mqs = torch.jit.trace(mq, example_args)
        res1 = mqs(*example_args)
        mqs = remove_redundant_aliases(mqs)
        res2 = mqs(*example_args)
        self.assertTrue(torch.allclose(res1, res2))
1 change: 1 addition & 0 deletions tools/build_variables.bzl
@@ -278,6 +278,7 @@ core_sources_full_mobile_no_backend_interface = [
"torch/csrc/jit/passes/remove_mutation.cpp",
"torch/csrc/jit/passes/prepack_folding.cpp",
"torch/csrc/jit/passes/fold_conv_bn.cpp",
"torch/csrc/jit/passes/dbr_quantization/remove_redundant_aliases.cpp",
"torch/csrc/jit/passes/frozen_concat_linear.cpp",
"torch/csrc/jit/passes/frozen_conv_add_relu_fusion.cpp",
"torch/csrc/jit/passes/frozen_conv_folding.cpp",
15 changes: 15 additions & 0 deletions torch/ao/quantization/_dbr/torchscript_utils.py
@@ -0,0 +1,15 @@
import torch
from torch.jit._recursive import wrap_cpp_module

def remove_redundant_aliases(scripted_module: torch.nn.Module):
    """
    Running torch.jit.trace on a model with DBR quantization introduces
    extra alias ops, because we use `torch.Tensor.as_subclass` and tracing
    through this results in an `aten::alias` function call in TorchScript.
    This pass removes these alias calls when it is safe to do so.
    """
    module_c = scripted_module._c
    module_c = \
        torch._C._jit_pass_dbr_quant_remove_redundant_aliases(module_c)  # type: ignore[attr-defined]
    scripted_module = wrap_cpp_module(module_c)
    return scripted_module
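
Not part of the diff: a minimal usage sketch of this helper, mirroring the tests above (`model` is a placeholder for any module supported by DBR quantization; import paths follow the test file):
```
import torch
import torch.ao.quantization._quantize_dbr as _quantize_dbr
from torch.ao.quantization._dbr.torchscript_utils import remove_redundant_aliases

qconfig_dict = {'': torch.quantization.default_qconfig}
example_args = (torch.randn(1, 1, 1, 1),)  # placeholder input shape
mp = _quantize_dbr.prepare(model, qconfig_dict, example_args)
mq = _quantize_dbr.convert(mp)
mqs = torch.jit.trace(mq, example_args)
mqs = remove_redundant_aliases(mqs)  # same numerics, alias nodes removed
```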
74 changes: 74 additions & 0 deletions torch/csrc/jit/passes/dbr_quantization/remove_redundant_aliases.cpp
@@ -0,0 +1,74 @@
#include <torch/csrc/jit/passes/dbr_quantization/remove_redundant_aliases.h>

#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/quantization/helper.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>

namespace torch {
namespace jit {

namespace {

void DBRQuantRemoveRedundantAliasesImpl(const Method& method) {
  auto g = method.graph();
  const bool is_frozen = false;
  const bool descend_function_calls = true;
  AliasDb alias_db(g, is_frozen, descend_function_calls);
  // find the alias nodes
  std::vector<Node*> alias_nodes;
  DepthFirstGraphNodeIterator it(g);
  Node* node = nullptr;
  while ((node = it.next()) != nullptr) {
    if (node->kind() == Symbol::aten("alias")) {
      alias_nodes.push_back(node);
    }
  }

  // remove the alias nodes, if it is safe to do so
  for (auto* node : alias_nodes) {
    GRAPH_DEBUG(*node);

    Value* input_value = node->input();
    Value* output_value = node->output();

    bool always_safe_to_mutate = alias_db.safeToChangeAliasingRelationship(
        node->inputs(), node->outputs());

    const auto g_in = g->inputs();
    const auto g_out = g->outputs();
    bool is_input =
        std::find(g_in.begin(), g_in.end(), input_value) != g_in.end();
    bool is_output =
        std::find(g_out.begin(), g_out.end(), output_value) != g_out.end();
    // We assume that aliasing is safe to update on inputs and outputs if they
    // do not have writers.
    bool input_safe_to_mutate =
        (is_input && !alias_db.hasWriters(input_value) &&
         !alias_db.hasWriters(output_value));
    bool output_safe_to_mutate =
        (is_output && !alias_db.hasWriters(input_value) &&
         !alias_db.hasWriters(output_value));

    if (always_safe_to_mutate || input_safe_to_mutate ||
        output_safe_to_mutate) {
      output_value->replaceAllUsesWith(input_value);
      node->destroy();
    }
  }
}

} // namespace

Module DBRQuantRemoveRedundantAliases(Module& module) {
  for (const auto& child : module.modules()) {
    for (const auto& method : child.get_methods()) {
      DBRQuantRemoveRedundantAliasesImpl(method);
    }
  }

  return module;
}

} // namespace jit
} // namespace torch
21 changes: 21 additions & 0 deletions torch/csrc/jit/passes/dbr_quantization/remove_redundant_aliases.h
@@ -0,0 +1,21 @@
#pragma once

#include <torch/csrc/jit/api/module.h>

namespace torch {
namespace jit {

// This function replaces instances of
//
// %b = aten::alias(%a)
// %c = foo(%b)
//
// with
//
// %c = foo(%a)
//
// on the module forward, if it's safe to do so.
TORCH_API Module DBRQuantRemoveRedundantAliases(Module& module);

} // namespace jit
} // namespace torch
4 changes: 4 additions & 0 deletions torch/csrc/jit/python/init.cpp
@@ -21,6 +21,7 @@
#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/create_functional_graphs.h>
#include <torch/csrc/jit/passes/cuda_graph_fuser.h>
#include <torch/csrc/jit/passes/dbr_quantization/remove_redundant_aliases.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/decompose_ops.h>
#include <torch/csrc/jit/passes/device_type_analysis.h>
@@ -267,6 +268,9 @@ void initJITBindings(PyObject* module) {
      .def(
          "_jit_pass_fold_convbn",
          [](Module& module) { return FoldConvBatchNorm(module); })
      .def(
          "_jit_pass_dbr_quant_remove_redundant_aliases",
          [](Module& module) { return DBRQuantRemoveRedundantAliases(module); })
      .def(
          "_freeze_module",
          [](Module& module,
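
For reference, this binding is what the Python helper in `torchscript_utils.py` calls under the hood; a sketch (`scripted` is any traced DBR-quantized module):
```
import torch
from torch.jit._recursive import wrap_cpp_module

# unwrap the C++ module, run the pass, then re-wrap into a Python module
module_c = scripted._c
module_c = torch._C._jit_pass_dbr_quant_remove_redundant_aliases(module_c)
scripted = wrap_cpp_module(module_c)
```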
