forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
dbr quant: add torchscript pass to remove redundant aliases (pytorch#…
…71230) Summary: Pull Request resolved: pytorch#71230 DBR quantization uses `torch.Tensor.as_subclass` frequently. When the quantized model is traced with `torch.jit.trace`, these calls appear in the resulting graph as `aten::alias`. This PR adds a pass to remove these calls from the graph, for two reasons: 1. ease of debugging (these calls do nothing) 2. less work for downstream passes (for example, converting to ONNX currently breaks if these alias calls are present) For now, we have to inline the graph in order for `aliasDb` to determine safety properly. In the future, we may choose to relax this if there is a need for it. Test Plan: Test plan is pretty basic for now, it can be improved in future PRs. ``` python test/test_quantization.py TestQuantizeDBR.test_jit_tracing_removes_aliases ``` Reviewed By: eellison Differential Revision: D33552387 Pulled By: vkuzo fbshipit-source-id: 681a33ddfff394a91e971263ac593afd93c5ea78 (cherry picked from commit 0f84127)
- Loading branch information
1 parent
eb8d065
commit bf896a2
Showing
6 changed files
with
161 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import torch | ||
from torch.jit._recursive import wrap_cpp_module | ||
|
||
def remove_redundant_aliases(scripted_module: torch.nn.Module):
    """
    Strip redundant `aten::alias` calls from a traced DBR-quantized module.

    DBR quantization relies on `torch.Tensor.as_subclass`, and tracing a
    model with `torch.jit.trace` records each of those calls as an
    `aten::alias` function call in TorchScript. The JIT pass invoked here
    removes those alias calls when it is safe to do so.

    The C++ module produced by the pass is wrapped back into a Python
    module via `wrap_cpp_module` and returned.
    """
    cleaned_c = torch._C._jit_pass_dbr_quant_remove_redundant_aliases(  # type: ignore[attr-defined]
        scripted_module._c
    )
    return wrap_cpp_module(cleaned_c)
74 changes: 74 additions & 0 deletions
74
torch/csrc/jit/passes/dbr_quantization/remove_redundant_aliases.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
#include <torch/csrc/jit/passes/dbr_quantization/remove_redundant_aliases.h> | ||
|
||
#include <torch/csrc/jit/ir/alias_analysis.h> | ||
#include <torch/csrc/jit/jit_log.h> | ||
#include <torch/csrc/jit/passes/quantization/helper.h> | ||
#include <torch/csrc/jit/runtime/graph_iterator.h> | ||
|
||
namespace torch { | ||
namespace jit { | ||
|
||
namespace { | ||
|
||
void DBRQuantRemoveRedundantAliasesImpl(const Method& method) { | ||
auto g = method.graph(); | ||
const bool is_frozen = false; | ||
const bool descend_function_calls = true; | ||
AliasDb alias_db(g, is_frozen, descend_function_calls); | ||
// find the alias nodes | ||
std::vector<Node*> alias_nodes; | ||
DepthFirstGraphNodeIterator it(g); | ||
Node* node = nullptr; | ||
while ((node = it.next()) != nullptr) { | ||
if (node->kind() == Symbol::aten("alias")) { | ||
alias_nodes.push_back(node); | ||
} | ||
} | ||
|
||
// remove the alias nodes, if it is safe to do so | ||
for (auto* node : alias_nodes) { | ||
GRAPH_DEBUG(*node); | ||
|
||
Value* input_value = node->input(); | ||
Value* output_value = node->output(); | ||
|
||
bool always_safe_to_mutate = alias_db.safeToChangeAliasingRelationship( | ||
node->inputs(), node->outputs()); | ||
|
||
const auto g_in = g->inputs(); | ||
const auto g_out = g->outputs(); | ||
bool is_input = | ||
std::find(g_in.begin(), g_in.end(), input_value) != g_in.end(); | ||
bool is_output = | ||
std::find(g_out.begin(), g_out.end(), output_value) != g_out.end(); | ||
// We assume that aliasing is safe to update on inputs and outputs if they | ||
// do not have writers. | ||
bool input_safe_to_mutate = | ||
(is_input && !alias_db.hasWriters(input_value) && | ||
!alias_db.hasWriters(output_value)); | ||
bool output_safe_to_mutate = | ||
(is_output && !alias_db.hasWriters(input_value) && | ||
!alias_db.hasWriters(output_value)); | ||
|
||
if (always_safe_to_mutate || input_safe_to_mutate || | ||
output_safe_to_mutate) { | ||
output_value->replaceAllUsesWith(input_value); | ||
node->destroy(); | ||
} | ||
} | ||
} | ||
|
||
} // namespace | ||
|
||
// Runs the alias-removal pass over every method of every module yielded by
// `module.modules()`, then returns `module`.
Module DBRQuantRemoveRedundantAliases(Module& module) {
  for (const auto& submodule : module.modules()) {
    for (const auto& fn : submodule.get_methods()) {
      DBRQuantRemoveRedundantAliasesImpl(fn);
    }
  }
  return module;
}
|
||
} // namespace jit | ||
} // namespace torch |
21 changes: 21 additions & 0 deletions
21
torch/csrc/jit/passes/dbr_quantization/remove_redundant_aliases.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
#pragma once | ||
|
||
#include <torch/csrc/jit/api/module.h> | ||
|
||
namespace torch { | ||
namespace jit { | ||
|
||
// This function replaces instances of | ||
// | ||
// %b = aten::alias(%a) | ||
// %c = foo(%b) | ||
// | ||
// with | ||
// | ||
// %c = foo(%a) | ||
// | ||
// in each method of each module yielded by `module.modules()`, if it's safe | ||
// to do so. | ||
TORCH_API Module DBRQuantRemoveRedundantAliases(Module& module); | ||
|
||
} // namespace jit | ||
} // namespace torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters