Skip to content

Commit

Permalink
RelayViz Graphviz renderer
Browse files Browse the repository at this point in the history
Following apache#10085, this PR adds a
graphviz backend. It requires python `graphviz` package and `dot`
executable in the PATH, similar to `tedd.py`.

This implementation is much like a porting of `visualize` function in
https://tvm.apache.org/2020/07/14/bert-pytorch-tvm, except that
`node_attr_dict` is replaced with a callback `get_node_attr`.

`get_node_attr` can be somehow used to emphasize a set of nodes.
It might be useful if we encounter problems in inferences
and want to find nodes with certain types and attributes.

An example is provided in
https://github.com/chiwwang/tvm/blob/graphviz_renderer_example/test_viz.py

Its outputs are (conv2d with NCHW layout is green-colored):
https://github.com/chiwwang/tvm/blob/graphviz_renderer_example/mod_with_subgraph.pdf
https://github.com/chiwwang/tvm/blob/graphviz_renderer_example/mod_wo_subgraph.pdf
  • Loading branch information
chiwwang committed Feb 27, 2022
1 parent 01f306f commit febe530
Show file tree
Hide file tree
Showing 6 changed files with 245 additions and 7 deletions.
4 changes: 3 additions & 1 deletion docs/reference/api/python/contrib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,12 @@ tvm.contrib.relay_viz
~~~~~~~~~~~~~~~~~~~~~
.. automodule:: tvm.contrib.relay_viz
:members:
.. automodule:: tvm.contrib.relay_viz.interface
.. automodule:: tvm.contrib.relay_viz.dot
:members:
.. automodule:: tvm.contrib.relay_viz.terminal
:members:
.. automodule:: tvm.contrib.relay_viz.interface
:members:


tvm.contrib.rocblas
Expand Down
2 changes: 2 additions & 0 deletions gallery/how_to/work_with_relay/using_relay_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
Here we use a renderer rendering graph in the text-form.
It is a lightweight, AST-like visualizer, inspired by `clang ast-dump <https://clang.llvm.org/docs/IntroductionToTheClangAST.html>`_.
We will introduce how to implement customized parsers and renderers through interface classes.
For more details, please refer to :py:mod:`tvm.contrib.relay_viz`.
"""
from typing import (
Dict,
Expand Down
10 changes: 9 additions & 1 deletion python/tvm/contrib/relay_viz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
TermPlotter,
TermVizParser,
)
from .dot import (
DotPlotter,
DotVizParser,
)


class RelayVisualizer:
Expand Down Expand Up @@ -69,12 +73,16 @@ def __init__(

node_to_id = {}
# callback to generate an unique string-ID for nodes.
# node_count_offset ensure each node ID is still unique across subgraph.
node_count_offset = 0

def traverse_expr(node):
if node in node_to_id:
return
node_to_id[node] = str(len(node_to_id))
node_to_id[node] = str(len(node_to_id) + node_count_offset)

for name in graph_names:
node_count_offset += len(node_to_id)
node_to_id.clear()
relay.analysis.post_order_visit(relay_mod[name], traverse_expr)
graph = self._plotter.create_graph(name)
Expand Down
221 changes: 221 additions & 0 deletions python/tvm/contrib/relay_viz/dot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Visualize Relay IR by Graphviz DOT language."""

from typing import (
Any,
Callable,
Dict,
)
from .interface import (
DefaultVizParser,
Plotter,
VizEdge,
VizGraph,
VizNode,
)

try:
import graphviz
except ImportError:
# add "from None" to silence
# "During handling of the above exception, another exception occurred"
raise ImportError(
"The graphviz package is required for DOT renderer. "
"Please install it first. For example, pip3 install graphviz"
) from None

DotVizParser = DefaultVizParser


class DotGraph(VizGraph):
"""DOT graph for relay IR.
See also :py:class:`tvm.contrib.relay_viz.dot.DotPlotter`
Parameters
----------
name: str
name of this graph.
graph_attr: Optional[Dict[str, str]]
key-value pairs for the graph.
node_attr: Optional[Dict[str, str]]
key-value pairs for all nodes.
edge_attr: Optional[Dict[str, str]]
key-value pairs for all edges.
get_node_attr: Optional[Callable[[VizNode], Dict[str, str]]]
A callable returning attributes for the node.
"""

def __init__(
self,
name: str,
graph_attr: Dict[str, str] = None,
node_attr: Dict[str, str] = None,
edge_attr: Dict[str, str] = None,
get_node_attr: Callable[[VizNode], Dict[str, str]] = None,
):
self._name = name
self._get_node_attr = self._default_get_node_attr
if get_node_attr is not None:
self._get_node_attr = get_node_attr

# graphviz recognizes the subgraph as a cluster subgraph
# by the name starting with "cluster" (all lowercase)
self._digraph = graphviz.Digraph(
name=f"cluster_{self._name}",
graph_attr=graph_attr,
node_attr=node_attr,
edge_attr=edge_attr,
)
self._digraph.attr(label=self._name)

def node(self, viz_node: VizNode) -> None:
"""Add a node to the underlying graph.
Nodes in a Relay IR Module are expected to be added in the post-order.
Parameters
----------
viz_node : VizNode
A `VizNode` instance.
"""
self._digraph.node(
viz_node.identity,
f"{viz_node.type_name}\n{viz_node.detail}",
**self._get_node_attr(viz_node),
)

def edge(self, viz_edge: VizEdge) -> None:
"""Add an edge to the underlying graph.
Parameters
----------
viz_edge : VizEdge
A `VizEdge` instance.
"""
self._digraph.edge(viz_edge.start, viz_edge.end)

@property
def digraph(self):
return self._digraph

@staticmethod
def _default_get_node_attr(node: VizNode):
if "Var" in node.type_name:
return {"shape": "ellipse"}
return {"shape": "box"}


class DotPlotter(Plotter):
"""DOT language graph plotter
The plotter accepts various graphviz attributes for graphs, nodes, and edges.
Please refer to https://graphviz.org/doc/info/attrs.html for available attributes.
Parameters
----------
graph_attr: Optional[Dict[str, str]]
key-value pairs for all graphs.
node_attr: Optional[Dict[str, str]]
key-value pairs for all nodes.
edge_attr: Optional[Dict[str, str]]
key-value pairs for all edges.
get_node_attr: Optional[Callable[[VizNode], Dict[str, str]]]
A callable returning attributes for a specific node.
render_kwargs: Optional[Dict[str, Any]]
keyword arguments directly passed to `graphviz.Digraph.render()`.
Examples
--------
.. code-block:: python
from tvm.contrib import relay_viz
from tvm.relay.testing import resnet
mod, param = resnet.get_workload(num_layers=18)
# graphviz attributes
graph_attr = {"color": "red"}
node_attr = {"color": "blue"}
edge_attr = {"color": "black"}
# VizNode is passed to the callback.
# We want to color NCHW conv2d nodes. Also give Var a different shape.
def get_node_attr(node):
if "nn.conv2d" in node.type_name and "NCHW" in node.detail:
return {
"fillcolor": "green",
"style": "filled",
"shape": "box",
}
if "Var" in node.type_name:
return {"shape": "ellipse"}
return {"shape": "box"}
# Create plotter and pass it to viz. Then render the graph.
dot_plotter = relay_viz.DotPlotter(
graph_attr=graph_attr,
node_attr=node_attr,
edge_attr=edge_attr,
get_node_attr=get_node_attr)
viz = relay_viz.RelayVisualizer(
mod,
relay_param=param,
plotter=dot_plotter,
parser=relay_viz.DotVizParser())
viz.render("hello")
"""

def __init__(
self,
graph_attr: Dict[str, str] = None,
node_attr: Dict[str, str] = None,
edge_attr: Dict[str, str] = None,
get_node_attr: Callable[[VizNode], Dict[str, str]] = None,
render_kwargs: Dict[str, Any] = None,
):
self._name_to_graph = {}
self._graph_attr = graph_attr
self._node_attr = node_attr
self._edge_attr = edge_attr
self._get_node_attr = get_node_attr

self._render_kwargs = {} if render_kwargs is None else render_kwargs

def create_graph(self, name):
self._name_to_graph[name] = DotGraph(
name, self._graph_attr, self._node_attr, self._edge_attr, self._get_node_attr
)
return self._name_to_graph[name]

def render(self, filename: str = None):
"""render the graph generated from the Relay IR module.
This function is a thin wrapper of `graphviz.Digraph.render()`.
"""
# Create or update the filename
if filename is not None:
self._render_kwargs["filename"] = filename
# default cleanup
if "cleanup" not in self._render_kwargs:
self._render_kwargs["cleanup"] = True

root_graph = graphviz.Digraph()
for graph in self._name_to_graph.values():
root_graph.subgraph(graph.digraph)
root_graph.render(**self._render_kwargs)
10 changes: 7 additions & 3 deletions python/tvm/contrib/relay_viz/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(self, node_id: str, node_type: str, node_detail: str):
self._detail = node_detail

@property
def identity(self) -> Union[int, str]:
def identity(self) -> str:
return self._id

@property
Expand All @@ -59,6 +59,10 @@ def type_name(self) -> str:
def detail(self) -> str:
return self._detail

def __repr__(self) -> str:
detail = self._detail.replace("\n", ", ")
return f"VizNode(identity: {self._id}, type_name: {self._type}, detail: {detail}"


class VizEdge:
"""VizEdge connect two `VizNode`.
Expand Down Expand Up @@ -139,7 +143,7 @@ def edge(self, viz_edge: VizEdge) -> None:
Parameters
----------
id_start : VizEdge
viz_edge : VizEdge
A `VizEdge` instance.
"""

Expand Down Expand Up @@ -277,7 +281,7 @@ def _tuple_get_item(
node_id = node_to_id[node]

# Tuple -> TupleGetItemNode
viz_node = VizNode(node_id, f"TupleGetItem", "idx: {node.index}")
viz_node = VizNode(node_id, f"TupleGetItem", f"idx: {node.index}")
viz_edges = [VizEdge(node_to_id[node.tuple_value], node_id)]
return viz_node, viz_edges

Expand Down
5 changes: 3 additions & 2 deletions python/tvm/contrib/relay_viz/terminal.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@
VizEdge,
VizGraph,
VizNode,
VizParser,
)


class TermVizParser(DefaultVizParser):
class TermVizParser(VizParser):
"""`TermVizParser` parse nodes and edges for `TermPlotter`."""

def __init__(self):
Expand Down Expand Up @@ -166,7 +167,7 @@ def edge(self, viz_edge: VizEdge) -> None:
Parameters
----------
id_start : VizEdge
viz_edge : VizEdge
A `VizEdge` instance.
"""
# Take CallNode as an example, instead of "arguments point to CallNode",
Expand Down

0 comments on commit febe530

Please sign in to comment.