Skip to content

Commit

Permalink
RelayViz interface and terminal ast-dump (#10085)
Browse files Browse the repository at this point in the history
* RelayViz interface and terminal ast-dump.

This PR follows #8668, with splitting
out interfaces class and terminal ast-dump implementation.

This visualizer is aimed for quick look-then-fix, so the interface is
simple. Despite that, customization is still possbile through
implementing interfaces defined in `interface.py` or overriding existent
implementations inside a renderer module, like `terminal.py`.

A tutorial is also provided in this PR.

A graphviz renderer will also be contributed after this PR.

* lint and typo
  • Loading branch information
chiwwang authored Feb 22, 2022
1 parent 9dd62b4 commit 55cfc4a
Show file tree
Hide file tree
Showing 5 changed files with 835 additions and 0 deletions.
10 changes: 10 additions & 0 deletions docs/reference/api/python/contrib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,16 @@ tvm.contrib.random
:members:


tvm.contrib.relay_viz
~~~~~~~~~~~~~~~~~~~~~
.. automodule:: tvm.contrib.relay_viz
:members:
.. automodule:: tvm.contrib.relay_viz.interface
:members:
.. automodule:: tvm.contrib.relay_viz.terminal
:members:


tvm.contrib.rocblas
~~~~~~~~~~~~~~~~~~~
.. automodule:: tvm.contrib.rocblas
Expand Down
159 changes: 159 additions & 0 deletions gallery/how_to/work_with_relay/using_relay_viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=line-too-long
"""
Use Relay Visualizer to Visualize Relay
============================================================
**Author**: `Chi-Wei Wang <https://github.com/chiwwang>`_
Relay IR module can contain lots of operations. Although an individual
operation is usually easy to understand, putting them together can cause
a complicated, hard-to-read graph. Things can get even worse with optimiztion-passes
coming into play.
This utility visualizes an IR module as nodes and edges. It defines a set of interfaces including
parser, plotter(renderer), graph, node, and edges.
A default parser is provided. Users can implement their own renderers to render the graph.
Here we use a renderer rendering graph in the text-form.
It is a lightweight, AST-like visualizer, inspired by `clang ast-dump <https://clang.llvm.org/docs/IntroductionToTheClangAST.html>`_.
We will introduce how to implement customized parsers and renderers through interface classes.
"""
from typing import (
Dict,
Union,
Tuple,
List,
)
import tvm
from tvm import relay
from tvm.contrib import relay_viz
from tvm.contrib.relay_viz.interface import (
VizEdge,
VizNode,
VizParser,
)
from tvm.contrib.relay_viz.terminal import (
TermGraph,
TermPlotter,
TermVizParser,
)

######################################################################
# Define a Relay IR Module with multiple GlobalVar
# ------------------------------------------------
# Let's build an example Relay IR Module containing multiple ``GlobalVar``.
# We define an ``add`` function and call it in the main function.
data = relay.var("data")
bias = relay.var("bias")
add_op = relay.add(data, bias)
add_func = relay.Function([data, bias], add_op)
add_gvar = relay.GlobalVar("AddFunc")

input0 = relay.var("input0")
input1 = relay.var("input1")
input2 = relay.var("input2")
add_01 = relay.Call(add_gvar, [input0, input1])
add_012 = relay.Call(add_gvar, [input2, add_01])
main_func = relay.Function([input0, input1, input2], add_012)
main_gvar = relay.GlobalVar("main")

mod = tvm.IRModule({main_gvar: main_func, add_gvar: add_func})

######################################################################
# Render the graph with Relay Visualizer on the terminal
# ------------------------------------------------------
# The terminal can show a Relay IR module in text similar to clang AST-dump.
# We should see ``main`` and ``AddFunc`` function. ``AddFunc`` is called twice in the ``main`` function.
viz = relay_viz.RelayVisualizer(mod)
viz.render()

######################################################################
# Customize Parser for Interested Relay Types
# -------------------------------------------
# Sometimes we want to emphasize interested information, or parse things differently for a specific usage.
# It is possible to provide customized parsers as long as it obeys the interface.
# Here demostrate how to customize parsers for ``relay.var``.
# We need to implement abstract interface :py:class:`tvm.contrib.relay_viz.interface.VizParser`.
class YourAwesomeParser(VizParser):
def __init__(self):
self._delegate = TermVizParser()

def get_node_edges(
self,
node: relay.Expr,
relay_param: Dict[str, tvm.runtime.NDArray],
node_to_id: Dict[relay.Expr, str],
) -> Tuple[Union[VizNode, None], List[VizEdge]]:

if isinstance(node, relay.Var):
node = VizNode(node_to_id[node], "AwesomeVar", f"name_hint {node.name_hint}")
# no edge is introduced. So return an empty list.
return node, []

# delegate other types to the other parser.
return self._delegate.get_node_edges(node, relay_param, node_to_id)


######################################################################
# Pass the parser and an interested renderer to visualizer.
# Here we just the terminal renderer.
viz = relay_viz.RelayVisualizer(mod, {}, TermPlotter(), YourAwesomeParser())
viz.render()

######################################################################
# Customization around Graph and Plotter
# -------------------------------------------
# Besides parsers, we can also customize graph and renderers by implementing
# abstract class :py:class:`tvm.contrib.relay_viz.interface.VizGraph` and
# :py:class:`tvm.contrib.relay_viz.interface.Plotter`.
# Here we override the ``TermGraph`` defined in ``terminal.py`` for easier demo.
# We add a hook duplicating above ``AwesomeVar``, and make ``TermPlotter`` use the new class.
class AwesomeGraph(TermGraph):
def node(self, viz_node):
# add the node first
super().node(viz_node)
# if it's AwesomeVar, duplicate it.
if viz_node.type_name == "AwesomeVar":
duplicated_id = f"duplciated_{viz_node.identity}"
duplicated_type = "double AwesomeVar"
super().node(VizNode(duplicated_id, duplicated_type, ""))
# connect the duplicated var to the original one
super().edge(VizEdge(duplicated_id, viz_node.identity))


# override TermPlotter to use `AwesomeGraph` instead
class AwesomePlotter(TermPlotter):
def create_graph(self, name):
self._name_to_graph[name] = AwesomeGraph(name)
return self._name_to_graph[name]


viz = relay_viz.RelayVisualizer(mod, {}, AwesomePlotter(), YourAwesomeParser())
viz.render()

######################################################################
# Summary
# -------
# This tutorial demonstrates the usage of Relay Visualizer and customization.
# The class :py:class:`tvm.contrib.relay_viz.RelayVisualizer` is composed of interfaces
# defined in ``interface.py``.
#
# It is aimed for quick look-then-fix iterations.
# The constructor arguments are intended to be simple, while the customization is still
# possible through a set of interface classes.
#
105 changes: 105 additions & 0 deletions python/tvm/contrib/relay_viz/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Relay IR Visualizer"""
from typing import Dict
import tvm
from tvm import relay
from .interface import (
Plotter,
VizGraph,
VizParser,
)
from .terminal import (
TermPlotter,
TermVizParser,
)


class RelayVisualizer:
"""Relay IR Visualizer
Parameters
----------
relay_mod: tvm.IRModule
Relay IR module.
relay_param: None | Dict[str, tvm.runtime.NDArray]
Relay parameter dictionary. Default `None`.
plotter: Plotter
An instance of class inheriting from Plotter interface.
Default is an instance of `terminal.TermPlotter`.
parser: VizParser
An instance of class inheriting from VizParser interface.
Default is an instance of `terminal.TermVizParser`.
"""

def __init__(
self,
relay_mod: tvm.IRModule,
relay_param: Dict[str, tvm.runtime.NDArray] = None,
plotter: Plotter = None,
parser: VizParser = None,
):
self._plotter = plotter if plotter is not None else TermPlotter()
self._relay_param = relay_param if relay_param is not None else {}
self._parser = parser if parser is not None else TermVizParser()

global_vars = relay_mod.get_global_vars()
graph_names = []
# If we have main function, put it to the first.
# Then main function can be shown on the top.
for gv_node in global_vars:
if gv_node.name_hint == "main":
graph_names.insert(0, gv_node.name_hint)
else:
graph_names.append(gv_node.name_hint)

node_to_id = {}
# callback to generate an unique string-ID for nodes.
def traverse_expr(node):
if node in node_to_id:
return
node_to_id[node] = str(len(node_to_id))

for name in graph_names:
node_to_id.clear()
relay.analysis.post_order_visit(relay_mod[name], traverse_expr)
graph = self._plotter.create_graph(name)
self._add_nodes(graph, node_to_id)

def _add_nodes(self, graph: VizGraph, node_to_id: Dict[relay.Expr, str]):
"""add nodes and to the graph.
Parameters
----------
graph : VizGraph
a VizGraph for nodes to be added to.
node_to_id : Dict[relay.expr, str]
a mapping from nodes to an unique ID.
relay_param : Dict[str, tvm.runtime.NDarray]
relay parameter dictionary.
"""
for node in node_to_id:
viz_node, viz_edges = self._parser.get_node_edges(node, self._relay_param, node_to_id)
if viz_node is not None:
graph.node(viz_node)
for edge in viz_edges:
graph.edge(edge)

def render(self, filename: str = None) -> None:
self._plotter.render(filename=filename)
Loading

0 comments on commit 55cfc4a

Please sign in to comment.