Skip to content

Commit

Permalink
define get_dot_graph (pytorch#70541)
Browse files Browse the repository at this point in the history
Summary:
In the [docstring](https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/graph_drawer.py#L54-L60) we mention `get_dot_graph but it is not defined, so I defined it here.
Not sure if this is preferred, or should we update the docstring to use `get_main_dot_graph`

Pull Request resolved: pytorch#70541

Test Plan:
```
            g = FxGraphDrawer(symbolic_traced, "resnet18")
            with open("a.svg", "w") as f:
                f.write(g.get_dot_graph().create_svg())
```

Reviewed By: khabinov

Differential Revision: D33378080

Pulled By: mostafaelhoushi

fbshipit-source-id: 7feea2425a12d5628ddca15beff0fe5110f4a111
  • Loading branch information
mostafaelhoushi authored and facebook-github-bot committed Jan 6, 2022
1 parent 917d56a commit 3f53365
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions torch/fx/passes/graph_drawer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ def __init__(self, graph_module: torch.fx.GraphModule, name: str, ignore_getattr

self._dot_graphs[f"{name}_{node.target}"] = self._to_dot(leaf_node, f"{name}_{node.target}", ignore_getattr)

def get_dot_graph(self, submod_name=None) -> pydot.Dot:
if submod_name is None:
return self.get_main_dot_graph()
else:
return self.get_submod_dot_graph(submod_name)

def get_main_dot_graph(self) -> pydot.Dot:
return self._dot_graphs[self._name]

Expand Down

0 comments on commit 3f53365

Please sign in to comment.