Skip to content

[torch.fx] Fix pattern matching the same node multiple times #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: thomas/fix_replace_pattern_in_torch_fx
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions test/fx/test_subgraph_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,3 +458,88 @@ def forward(self, x):
if n.op == 'placeholder':
assert n.type == int
assert m.type == int

def test_subgraph_writer_replace_consecutive_submodules(self):
    # Three chained sigmoid calls must each be rewritten to exp, even
    # though every match feeds directly into the next one.

    def f(x):
        return torch.sigmoid(torch.sigmoid(torch.sigmoid(x)))

    def pattern(x):
        return torch.sigmoid(x)

    def replacement(x):
        return torch.exp(x)

    def comparison(x):
        return torch.exp(torch.exp(torch.exp(x)))

    rewritten = symbolic_trace(f)
    expected_module = symbolic_trace(comparison)

    sample = torch.randn(3, 4)

    subgraph_rewriter.replace_pattern(rewritten, pattern, replacement)

    # The rewritten graph has to remain structurally valid.
    rewritten.graph.lint()

    expected = expected_module(sample)
    actual = rewritten.forward(sample)
    self.assertEqual(expected, actual)

def test_subgraph_rewriter_replaces_parallel_functions(self):
    # Two independent sigmoid calls on the same input, both returned
    # from the graph, must each be rewritten to relu.

    def f(x):
        return torch.sigmoid(x), torch.sigmoid(x)

    def pattern(x):
        return torch.sigmoid(x)

    def replacement(x):
        return torch.relu(x)

    def comparison(x):
        return torch.relu(x), torch.relu(x)

    rewritten = symbolic_trace(f)

    subgraph_rewriter.replace_pattern(rewritten, pattern, replacement)
    # The rewritten graph has to remain structurally valid.
    rewritten.graph.lint()

    sample = torch.randn(3, 4)
    expected = comparison(sample)
    actual = rewritten.forward(sample)
    self.assertEqual(expected, actual)

def test_subgraph_rewriter_replaces_parallel_functions_when_aggregated(self):
    # Two independent sigmoid calls on the same input, summed before
    # being returned, must each be rewritten to relu.

    def f(x):
        return torch.sigmoid(x) + torch.sigmoid(x)

    def pattern(x):
        return torch.sigmoid(x)

    def replacement(x):
        return torch.relu(x)

    def comparison(x):
        return torch.relu(x) + torch.relu(x)

    rewritten = symbolic_trace(f)

    subgraph_rewriter.replace_pattern(rewritten, pattern, replacement)
    # The rewritten graph has to remain structurally valid.
    rewritten.graph.lint()

    sample = torch.randn(3, 4)
    expected = comparison(sample)
    actual = rewritten.forward(sample)
    self.assertEqual(expected, actual)
198 changes: 106 additions & 92 deletions torch/fx/subgraph_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,30 @@ def __init__(self, pattern: Graph) -> None:
assert len(self.pattern_anchor.all_input_nodes) == 1, \
"Pattern matching on multiple outputs is not supported"
# Maps nodes in the pattern subgraph to nodes in the larger graph
self.nodes_map: Dict[Node, Node] = {}
self.nodes_map: List[Dict[Node, Node]] = [{}]

def matches_subgraph_from_anchor(self, anchor: Node) -> bool:
def matches_subgraph_from_anchor(self, anchor: Node) -> List[Dict[Node, Node]]:
"""
Checks if the whole pattern can be matched starting from
``anchor`` in the larger graph.

Pattern matching is done by recursively comparing the pattern
node's use-def relationships against the graph node's.
"""
self.nodes_map = {}
return self._match_nodes(self.pattern_anchor, anchor)
self.nodes_map: List[Dict[Node, Node]] = [{}]
self._match_nodes(self.pattern_anchor, anchor)

# We need to filter out the ones that are empty
self.nodes_map = [elt for elt in self.nodes_map if len(elt) > 0]
return self.nodes_map

# Compare the pattern node `pn` against the graph node `gn`
def _match_nodes(self, pn: Node, gn: Node) -> bool:
def _match_nodes(self, pn: Node, gn: Node, graph_id: int = 0) -> bool:

# Check if we've already matched these nodes in the current
# traversal
if pn in self.nodes_map:
return self.nodes_map[pn] == gn
if pn in self.nodes_map[graph_id]:
return self.nodes_map[graph_id][pn] == gn

def attributes_are_equal(pn: Node, gn: Node) -> bool:
# Use placeholder and output nodes as wildcards. The
Expand All @@ -63,7 +67,7 @@ def attributes_are_equal(pn: Node, gn: Node) -> bool:
return False

# Optimistically mark `pn` as a match for `gn`
self.nodes_map[pn] = gn
self.nodes_map[graph_id][pn] = gn

# Traverse the use-def relationships to ensure that `pn` is a true
# match for `gn`
Expand All @@ -73,14 +77,22 @@ def attributes_are_equal(pn: Node, gn: Node) -> bool:
and len(pn.all_input_nodes) != len(gn.all_input_nodes)):
return False
if pn.op == "output":
match_found = any(self._match_nodes(pn.all_input_nodes[0], gn_)
for gn_ in gn.all_input_nodes)
# Only the first graph compares the output.
assert graph_id == 0
# We broadcast the result to all the other potential graph matches.
self.nodes_map += [copy.copy(self.nodes_map[graph_id]) for _ in range(len(gn.all_input_nodes) - 1)]
all_matches = tuple(self._match_nodes(pn.all_input_nodes[0], gn_, graph_id_)
for graph_id_, gn_ in enumerate(gn.all_input_nodes)
)
self.nodes_map = [node_map for node_map, match in zip(self.nodes_map, all_matches) if match]
# Returning this value is not strictly needed
return any(all_matches)
else:
match_found = (len(pn.all_input_nodes) == len(gn.all_input_nodes)
and all(self._match_nodes(pn_, gn_) for pn_, gn_
and all(self._match_nodes(pn_, gn_, graph_id) for pn_, gn_
in zip(pn.all_input_nodes, gn.all_input_nodes)))
if not match_found:
self.nodes_map.pop(pn)
self.nodes_map[graph_id].pop(pn)
return False

return True
Expand Down Expand Up @@ -256,64 +268,67 @@ def forward(self, x, w1, w2):
matcher = _SubgraphMatcher(pattern_graph)
matches: List[Match] = []

# Consider each node as an "anchor" (deepest matching graph node)
for anchor in original_graph.nodes:
def pattern_is_contained(nodes_map: Dict[Node, Node]) -> bool:
# `lookup` represents all the nodes in `original_graph`
# that are part of `pattern`
lookup: Dict[Node, Node] = {v: k for k, v in nodes_map.items()}
for n in lookup.keys():

if matcher.matches_subgraph_from_anchor(anchor):

def pattern_is_contained(nodes_map : Dict[Node, Node]) -> bool:
# `lookup` represents all the nodes in `original_graph`
# that are part of `pattern`
lookup: Dict[Node, Node] = {v : k for k, v
in nodes_map.items()}
for n in lookup.keys():

# Nodes that can "leak"...

# Placeholders (by definition)
if n.op == "placeholder":
continue
# Pattern output (acts as a container)
if lookup[n].op == "output":
continue
# Result contained by pattern output (what we'll
# hook in to the new Graph, thus what we'll
# potentially use in other areas of the Graph as
# an input Node)
if (len(lookup[n].users) == 1
and list(lookup[n].users.keys())[0].op == "output"):
continue

for user in n.users:
# If this node has users that were not in
# `lookup`, then it must leak out of the
# pattern subgraph
if user not in lookup:
return False
return True
# Nodes that can "leak"...

# Placeholders (by definition)
if n.op == "placeholder":
continue
# Pattern output (acts as a container)
if lookup[n].op == "output":
continue
# Nodes matched by a pattern placeholder (by definition)
if lookup[n].op == "placeholder":
continue
# Result contained by pattern output (what we'll
# hook in to the new Graph, thus what we'll
# potentially use in other areas of the Graph as
# an input Node)
if (len(lookup[n].users) == 1
and list(lookup[n].users.keys())[0].op == "output"):
continue

# It's not a match if the pattern leaks out into the rest
# of the graph
if pattern_is_contained(matcher.nodes_map):
for k, v in matcher.nodes_map.items():
# Shallow copy nodes_map
matches.append(Match(anchor=anchor,
nodes_map=copy.copy(matcher.nodes_map)))
for user in n.users:
# If this node has users that were not in
# `lookup`, then it must leak out of the
# pattern subgraph
if user not in lookup:
return False
return True

# Consider each node as an "anchor" (deepest matching graph node)
for anchor in original_graph.nodes:
potential_matches = matcher.matches_subgraph_from_anchor(anchor)
# It's not a match if the pattern leaks out into the rest
# of the graph
for node_map in potential_matches:
if pattern_is_contained(node_map):
# Shallow copy nodes_map
matches.append(Match(anchor=anchor,
nodes_map=copy.copy(node_map)))

# The set of all nodes in `original_graph` that we've seen thus far
# as part of a pattern match
replaced_nodes: Set[Node] = set()
# As we progressively replace nodes, we also need to keep track of how the match results need to change
match_changed_node: Dict[Node, Node] = dict()

# Return True if one of the nodes in the current match has already
# been used as part of another match
def overlaps_with_prev_match(match: Match) -> bool:
for n in match.nodes_map.values():
if n in replaced_nodes and n.op != "placeholder":
for pn, gn in match.nodes_map.items():
if pn.op in ["placeholder", "output"]:
continue
if gn in replaced_nodes and gn.op != "placeholder":
return True
return False

for match in matches:

for i, match in enumerate(matches):
# Skip overlapping matches
if overlaps_with_prev_match(match):
continue
Expand All @@ -327,7 +342,7 @@ def overlaps_with_prev_match(match: Match) -> bool:
replacement_placeholders = [n for n in replacement_graph.nodes
if n.op == "placeholder"]
assert len(pattern_placeholders) == len(replacement_placeholders)
placeholder_map = {r : p for r, p
placeholder_map = {r: p for r, p
in zip(replacement_placeholders, pattern_placeholders)}

# node from `original_graph` that matched with the output node
Expand All @@ -341,15 +356,17 @@ def mark_node_as_replaced(n: Node) -> None:
mark_node_as_replaced(n_)
replaced_nodes.add(n)

mark_node_as_replaced(subgraph_output)
for input_node in subgraph_output.all_input_nodes:
mark_node_as_replaced(input_node)

# Intialize `val_map` with mappings from placeholder nodes in
# Initialize `val_map` with mappings from placeholder nodes in
# `replacement` to their corresponding node in `original_graph`
for replacement_node in replacement_placeholders:
# Get the `original_graph` placeholder node
# corresponding to the current `replacement_node`
pattern_node = placeholder_map[replacement_node]
original_graph_node = match.nodes_map[pattern_node]
original_graph_node = match_changed_node.get(match.nodes_map[pattern_node], match.nodes_map[pattern_node])

# Populate `val_map`
val_map[replacement_node] = original_graph_node

Expand All @@ -361,39 +378,36 @@ def mark_node_as_replaced(n: Node) -> None:
# Hook the output Node of the replacement subgraph in to the
# original Graph at the correct location

# CASE 1: We need to hook the replacement subgraph in somewhere
# in the middle of the graph. We replace the Node in the
# original graph that corresponds to the end of the pattern
# subgraph
if subgraph_output.op != "output":
# `subgraph_output` may have multiple args. These args could
# be from the orignal graph, or they could have come from
# the insertion of `replacement_subgraph`. We need to find
# the Node that was originally matched as part of
# `pattern` (i.e. a Node from the original graph). We can
# figure this out by looking in `match.nodes_map`. The map
# was created before `replacement_subgraph` was spliced in,
# so we know that, if a Node is in `match.nodes_map.values`,
# it must have come from the original graph
for n in subgraph_output.all_input_nodes:
if (n.op != "placeholder"
and n in match.nodes_map.values()):
subgraph_output = n
break
assert subgraph_output.op != "output"
# CASE 2: The pattern subgraph match extends to the end of the
# original graph, so we need to change the current graph's
# output Node to reflect the insertion of the replacement graph.
# We'll keep the current output Node, but update its args and
# `_input_nodes` as necessary
else:
subgraph_output.args = ((copied_output,))
if isinstance(copied_output, Node):
subgraph_output._input_nodes = {copied_output: None}
pattern_outputs = [n for n in pattern_graph.nodes
if n.op == "output"]
assert len(pattern_outputs)
replacement_outputs = [n for n in replacement_graph.nodes
if n.op == "output"]
assert len(replacement_outputs) == len(pattern_outputs)
outputs_map = {p: r for r, p
in zip(replacement_outputs, pattern_outputs)}

for pn, gn in match.nodes_map.items():
if gn.op == "placeholder":
continue

assert isinstance(copied_output, Node)
subgraph_output.replace_all_uses_with(copied_output)
# We search for the node corresponding to the output of the pattern.
if pn.op != "output":
continue

# the anchor should correspond to `subgraph_output`
assert subgraph_output == gn

# We update all anchor inputs to the new nodes
rn = outputs_map[pn]
for pn_input, rn_input in zip(pn.all_input_nodes, rn.all_input_nodes):
gn_input = match.nodes_map[pn_input]
rn_input_in_original_graph = val_map[rn_input]
gn_input.replace_all_uses_with(rn_input_in_original_graph)
# We store the replacement node in case later matches refer to the node it replaced
match_changed_node[gn_input] = rn_input_in_original_graph

assert isinstance(copied_output, Node)
# Erase the `pattern` nodes
for node in reversed(original_graph.nodes):
if len(node.users) == 0 and node.op != "output":
Expand Down