 from .. import expr as _expr
 from .. import op as _op
 from .. import ty as _ty
-from ..expr_functor import ExprFunctor
+from ..expr_functor import ExprVisitor
 from . import _backend
 
 @register_relay_node
 class CachedFunc(NodeBase):
     """Low-level tensor function to back a relay primitive function.
     """
     def __init__(self, target, func_name, inputs, outputs, schedule=None,
-                 lowered_funcs=[], shape_func_param_states=[]):
+                 lowered_funcs=None, shape_func_param_states=None):
+        if lowered_funcs is None:
+            lowered_funcs = []
+        if shape_func_param_states is None:
+            shape_func_param_states = []
         self.__init_handle_by_constructor__(
             _backend._make_CachedFunc, target, func_name, inputs, outputs,
             schedule, lowered_funcs, shape_func_param_states)
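The switch from list defaults to a None sentinel above sidesteps Python's shared mutable default pitfall: a default list is created once, when the function is defined, and then reused by every call that omits the argument, so state can leak between otherwise independent calls. A minimal standalone illustration of the pattern, not part of this patch (register_bad and register_good are made-up names):

    # Illustration only: the default list is created once and shared.
    def register_bad(name, funcs=[]):
        funcs.append(name)
        return funcs

    register_bad("a")   # ['a']
    register_bad("b")   # ['a', 'b']  state leaked from the first call

    # The None-sentinel pattern used in the patch gives each call a fresh list.
    def register_good(name, funcs=None):
        if funcs is None:
            funcs = []
        funcs.append(name)
        return funcs

    register_good("a")  # ['a']
    register_good("b")  # ['b']
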
@@ -79,6 +83,7 @@ def _get_cache_key(source_func, target):
 
 
 def get_shape(shape):
+    """Convert the shape to correct dtype and vars."""
     ret = []
     for dim in shape:
         if isinstance(dim, tvm.expr.IntImm):
@@ -92,7 +97,9 @@ def get_shape(shape):
     return ret
 
 
-class ScheduleGetter(ExprFunctor):
+class ScheduleGetter(ExprVisitor):
+    """Get the schedule given a fused Relay function"""
+
     MAX_FUNC_NAME_LENGTH = 80
 
     def __init__(self, target):
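ScheduleGetter now derives from ExprVisitor, which supplies a default recursive traversal, so only the node types it cares about need explicit visit_* overrides; ExprFunctor, by contrast, leaves every visit_* method unimplemented. A short sketch of that usage pattern, assuming the usual relay.expr_functor.ExprVisitor interface (visit dispatching to visit_call, visit_var, ...); the CallCounter class is hypothetical and not part of this patch:

    from tvm import relay
    from tvm.relay.expr_functor import ExprVisitor

    class CallCounter(ExprVisitor):
        """Counts Call nodes; every other node type falls back to
        ExprVisitor's default recursive traversal."""
        def __init__(self):
            super().__init__()
            self.count = 0

        def visit_call(self, call):
            self.count += 1
            # Defer to the default implementation so operands are still visited.
            super().visit_call(call)

    x = relay.var("x", shape=(1, 10))
    body = relay.add(relay.exp(x), x)
    counter = CallCounter()
    counter.visit(body)
    # counter.count == 2 (one Call each for exp and add)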