[JIT] Revert Freezing shared type PR (pytorch#46285)
Summary:
Fixes pytorch#45902 by reverting pytorch#42457

The test case introduced by pytorch#42457 was fixed by pytorch#46250, which I'm assuming is the real source of the bug.

In the future it would be good to provide repros for freezing issues without including a quantization dependency; there was another issue in freezing (see: pytorch#46054) whose root cause was the same quantization issue fixed by pytorch#46250.
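
For context, torch.jit.freeze takes a ScriptModule in eval mode and inlines attributes and parameters into the forward graph where the pass can prove it safe. A minimal, quantization-free sketch of the API (the module M below is illustrative, not taken from this PR):

import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.weight = nn.Parameter(torch.ones(3))

    def forward(self, x):
        return x + self.weight

# freeze() expects a ScriptModule with training mode off; attributes the
# pass can prove constant are baked into the graph as constants.
frozen = torch.jit.freeze(torch.jit.script(M().eval()))
print(frozen(torch.zeros(3)))  # tensor([1., 1., 1.])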

Pull Request resolved: pytorch#46285

Reviewed By: bdhirsh

Differential Revision: D24288739

Pulled By: eellison

fbshipit-source-id: b69ee8c713f749cd93d5eba370c3eafed86568bb
Elias Ellison authored and facebook-github-bot committed Oct 15, 2020
1 parent b547973 commit 908c235
Showing 2 changed files with 48 additions and 26 deletions.
26 changes: 26 additions & 0 deletions test/jit/test_freezing.py
@@ -1195,3 +1195,29 @@ def _static_quant(model):
        # It used to segfault while running frozen module.
        m_frozen_res = m_frozen(data)
        self.assertEqual(m_res, m_frozen_res)

    def test_module_getattr_indirection(self):
        @torch.jit.script
        class ValHolder(object):
            def __init__(self, val: int):
                self.val: int = val

        class Mod(nn.Module):
            def __init__(self):
                super(Mod, self).__init__()
                self.mod1 = ValHolder(1)
                self.mod2 = ValHolder(2)

            def forward(self, cond: bool):
                if cond:
                    mod = self.mod1
                else:
                    mod = self.mod2
                return mod.val

        mod = Mod()
        mod.eval()
        frozen_mod = torch.jit.freeze(torch.jit.script(mod))
        mod_eager = Mod()
        self.assertEqual(mod_eager(True), frozen_mod(True))
        self.assertEqual(mod_eager(False), frozen_mod(False))
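
As a quick sanity check outside the test harness, one can inspect the frozen graph to see how the shared-type GetAttr indirection was handled (a sketch reusing the Mod definition above; the expected results follow from the test):

mod = torch.jit.script(Mod().eval())
frozen_mod = torch.jit.freeze(mod)
# mod1 and mod2 share the same ValHolder type, so neither may be blindly
# inlined as a constant; both branches must survive freezing.
print(frozen_mod.graph)
assert frozen_mod(True) == 1 and frozen_mod(False) == 2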
48 changes: 22 additions & 26 deletions torch/csrc/jit/passes/freeze_module.cpp
@@ -11,25 +11,6 @@ namespace torch {
namespace jit {

namespace {
ModulePtr getModulePtrForGetAttrNode(
    const Node* node,
    const std::shared_ptr<Graph>& graph,
    const Module& graph_input_module) {
  std::vector<std::string> names;
  names.clear();
  while (!(node->outputs()[0]->type() == graph->inputs()[0]->type())) {
    TORCH_INTERNAL_ASSERT(
        node->kind() == prim::GetAttr, "Expected prim::GetAttr nodes");
    names.insert(names.begin(), node->s(attr::name));
    node = node->inputs()[0]->node();
  }
  // Copy/paste from quantization/helper.h
  Module m = graph_input_module;
  for (const auto& p : names) {
    m = m.attr(p).toModule();
  }
  return m._ivalue();
}
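
In Python terms, the helper removed above walked a chain of prim::GetAttr nodes back to the graph input, collecting the attribute names along the way, and then resolved that path on the root module. An illustrative analogue (not the pass itself):

from functools import reduce

def resolve_attr_path(root, names):
    # e.g. names == ["sub", "inner"] resolves root.sub.inner
    return reduce(getattr, names, root)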

class AttributePropagator {
public:
@@ -553,7 +534,7 @@ class AttributePropagator {
    removeUnusedAttrs();
  }

  // Preparing for clean up phase. At this point, record all subModules that
  // contain mutable attributes.
  void recordReferencedAttrs(std::shared_ptr<Graph>& graph) {
    std::stack<Block*> blocks({graph->block()});
@@ -567,12 +548,27 @@
      }
      if (n->kind() == prim::GetAttr) {
        auto& name = n->s(attr::name);
        auto mptr =
            getModulePtrForGetAttrNode(n->input(0)->node(), graph, module_);
        auto module = Module(mptr);
        if (module.type() == n->inputs()[0]->type() && module.hasattr(name)) {
          auto attr = module.attr(name);
          insertMutableAttr(name, attr, mptr);
        // For now, use all module ivalues which are the same type
        // and could be the module that this GetAttr resolves to.
        // TODO: we could attempt to follow the GetAttr chain and
        // find the exact ivalue; we would have to be careful that
        // the chain does not contain any attributes which get
        // written to (setAttr calls).
        for (auto& mptr : modules) {
          auto module = Module(mptr);
          if (module.type() == n->inputs()[0]->type()) {
            TORCH_INTERNAL_ASSERT(module.hasattr(name));
            auto attr = module.attr(name);
            // TODO: this could be insertReferencedAttr to be more clear;
            // these are attributes we could not inline, which includes
            // other reasons besides mutation (unsupported constant,
            // getAttr resolving to a non-getAttr node, etc.).
            insertMutableAttr(name, attr, mptr);
            if (attr.isModule()) {
              modules.insert(attr.toModule()._ivalue());
            }
          }
        }
      } else if (n->kind() == prim::fork) {
        applyToForkSubgraph(
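
The TODO above is subtle: following the GetAttr chain to a single ivalue is only safe when nothing on the chain is reassigned at runtime. A hedged Python sketch of the hazard (SwappingMod is hypothetical, not part of this commit):

import torch
import torch.nn as nn

@torch.jit.script
class ValHolder(object):
    def __init__(self, val: int):
        self.val: int = val

class SwappingMod(nn.Module):
    def __init__(self):
        super(SwappingMod, self).__init__()
        self.mod1 = ValHolder(1)

    def forward(self, cond: bool):
        if cond:
            # A setAttr on the chain: after this, self.mod1 no longer
            # names one fixed ivalue, so freezing must keep it mutable.
            self.mod1 = ValHolder(3)
        return self.mod1.val

frozen = torch.jit.freeze(torch.jit.script(SwappingMod().eval()))
print(frozen(False), frozen(True))  # expected: 1 3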
