Skip to content

Commit 6ae6900

Browse files
committed
fix comments and add doc for context analysis
1 parent fbaee58 commit 6ae6900

File tree

2 files changed

+35
-7
lines changed

2 files changed

+35
-7
lines changed

python/tvm/relay/transform/memory_alloc.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from ..op.memory import flatten_tuple_type, from_tuple_type, to_tuple_type
3333
from ...import cpu
3434
from ..op.memory import alloc_storage
35-
from ..analysis import context_analysis as _context_analysis
35+
from ..analysis import context_analysis
3636
from ..._ffi.runtime_ctypes import TVMContext
3737

3838
def alloc_tensor(storage, shape, dtype='float32', assert_shape=None):
@@ -85,7 +85,7 @@ def is_reshape_only(func):
8585
class ManifestAllocPass(ExprMutator):
8686
"""A pass for explicitly manifesting all memory allocations in Relay."""
8787

88-
def __init__(self, target_host, context_analysis):
88+
def __init__(self, target_host, context_analysis_map):
8989
self.invoke_tvm = op.vm.invoke_tvm_op
9090
self.shape_func = op.vm.shape_func
9191
self.shape_of = op.vm.shape_of
@@ -94,13 +94,13 @@ def __init__(self, target_host, context_analysis):
9494
self.target_host = target_host
9595
self.default_context = cpu(0)
9696
self.compute_dtype = "int64"
97-
self.context_analysis = context_analysis
97+
self.context_analysis_map = context_analysis_map
9898
super().__init__()
9999

100100
def get_context(self, exp):
101101
"""Get the context of a given expression"""
102-
assert exp in self.context_analysis, exp.astext(False)
103-
val = self.context_analysis[exp]
102+
assert exp in self.context_analysis_map, exp.astext(False)
103+
val = self.context_analysis_map[exp]
104104
# val[0], val[1] are device_type and device_id, respectively.
105105
# We don't need to unpack after porting this pass to C++.
106106
assert len(val) == 2
@@ -339,6 +339,7 @@ def _annotator(exp):
339339
@module_pass(opt_level=0)
340340
class ManifestAlloc:
341341
"""The explicit pass wrapper around ManifestAlloc."""
342+
# TODO(zhiics, jroesch) Port this pass to C++.
342343
def __init__(self, target_host, targets):
343344
self.target_host = target_host
344345
self.targets = targets
@@ -356,13 +357,13 @@ def transform_module(self, mod, _):
356357
fallback_ctx = nd.context(pass_ctx.config["relay.fallback_device_type"])
357358
else:
358359
fallback_ctx = cpu(0)
359-
ca = _context_analysis(mod, TVMContext(fallback_ctx.device_type, 0))
360+
ca = context_analysis(mod, TVMContext(fallback_ctx.device_type, 0))
360361
else:
361362
if isinstance(self.targets, dict):
362363
dev = list(self.targets.keys())[0]
363364
else:
364365
dev, _ = self.targets.items()[0]
365-
ca = _context_analysis(mod, nd.context(dev.value))
366+
ca = context_analysis(mod, nd.context(dev.value))
366367

367368
# The following code can be used for debugging the module after
368369
# annotation.

src/relay/analysis/context_analysis.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,33 @@
2020
/*!
2121
* \file src/relay/analysis/context_analysis.cc
2222
* \brief A pass for analyzing device attribute of each IR node.
23+
*
24+
* We use union-find data structures to analyze the context information of each
25+
* sub-expression in a Relay program in this pass. Only the device copy node in
26+
* Relay directly contains bidirectional device information. We use it to
27+
* bidirectionally propagate the device info of its inputs and outputs.
28+
*
29+
* However, to support dynamism (e.g., dynamic inputs), Relay introduces several
30+
* concepts to compute the shape of tensors and operators at runtime, i.e.
31+
* shape_of, shape_func, and reshape_tensor. These nodes are also referred to as
32+
* VM dialects as we have native VM instructions for them. These dialects are
33+
* intrinsically CPU friendly, therefore, they are only designed to be
34+
* executed on CPU. We, hence, unify their inputs and outputs to CPU as well.
35+
* Note the input of shape_of is a tensor and we only need the tensor shape.
36+
* Therefore, the input could be sitting on GPU as well since no real data is
37+
* needed. The context of the input would be propagated from its other
38+
* consumers or fallback to the default device.
39+
*
40+
* Another type of dialect is used for memory allocation, namely, alloc_storage
41+
* and alloc_tensor. alloc_storage contains a context field to indicate where
42+
* the chunk of memory is allocated. Therefore, we unify the context of
43+
* alloc_storage with the context field. Other inputs, such as size and
44+
* alignment, are left on CPU.
45+
*
46+
* Based on the above rules, we keep unifying the connected expressions and
47+
* propagating their device information. An error will be raised whenever there
48+
* is a unification conflict. All IR nodes that are not propagated with device
49+
* context will fall back to the specified device.
2350
*/
2451

2552
#include <tvm/relay/analysis.h>

0 commit comments

Comments
 (0)