3232from ..op .memory import flatten_tuple_type , from_tuple_type , to_tuple_type
3333from ...import cpu
3434from ..op .memory import alloc_storage
35- from ..analysis import context_analysis as _context_analysis
35+ from ..analysis import context_analysis
3636from ..._ffi .runtime_ctypes import TVMContext
3737
3838def alloc_tensor (storage , shape , dtype = 'float32' , assert_shape = None ):
@@ -85,7 +85,7 @@ def is_reshape_only(func):
8585class ManifestAllocPass (ExprMutator ):
8686 """A pass for explicitly manifesting all memory allocations in Relay."""
8787
88- def __init__ (self , target_host , context_analysis ):
88+ def __init__ (self , target_host , context_analysis_map ):
8989 self .invoke_tvm = op .vm .invoke_tvm_op
9090 self .shape_func = op .vm .shape_func
9191 self .shape_of = op .vm .shape_of
@@ -94,13 +94,13 @@ def __init__(self, target_host, context_analysis):
9494 self .target_host = target_host
9595 self .default_context = cpu (0 )
9696 self .compute_dtype = "int64"
97- self .context_analysis = context_analysis
97+ self .context_analysis_map = context_analysis_map
9898 super ().__init__ ()
9999
100100 def get_context (self , exp ):
101101 """Get the context of a given expression"""
102- assert exp in self .context_analysis , exp .astext (False )
103- val = self .context_analysis [exp ]
102+ assert exp in self .context_analysis_map , exp .astext (False )
103+ val = self .context_analysis_map [exp ]
104104 # val[0], val[1] are device_type and device_id, respectively.
105105 # We don't need to unpack after porting this pass to C++.
106106 assert len (val ) == 2
@@ -339,6 +339,7 @@ def _annotator(exp):
339339@module_pass (opt_level = 0 )
340340class ManifestAlloc :
341341 """The explicit pass wrapper around ManifestAlloc."""
342+ # TODO(zhiics, jroesch) Port this pass to C++.
342343 def __init__ (self , target_host , targets ):
343344 self .target_host = target_host
344345 self .targets = targets
@@ -356,13 +357,13 @@ def transform_module(self, mod, _):
356357 fallback_ctx = nd .context (pass_ctx .config ["relay.fallback_device_type" ])
357358 else :
358359 fallback_ctx = cpu (0 )
359- ca = _context_analysis (mod , TVMContext (fallback_ctx .device_type , 0 ))
360+ ca = context_analysis (mod , TVMContext (fallback_ctx .device_type , 0 ))
360361 else :
361362 if isinstance (self .targets , dict ):
362363 dev = list (self .targets .keys ())[0 ]
363364 else :
364365 dev , _ = self .targets .items ()[0 ]
365- ca = _context_analysis (mod , nd .context (dev .value ))
366+ ca = context_analysis (mod , nd .context (dev .value ))
366367
367368 # The following code can be used for debugging the module after
368369 # annotation.
0 commit comments