@@ -903,6 +903,7 @@ def __init__(self):
903903 def dense_fixed (name : str , length : PrimExpr , span : Optional [Span ] = None ):
904904 var_name = self .node .lhs [0 ].id .name
905905 axis = DenseFixedAxis (name , length )
906+ self .context .sp_struct2param_map [axis ] = []
906907 self .context .update_symbol (var_name , axis , self .node )
907908
908909 super ().__init__ (dense_fixed , def_symbol = True )
@@ -926,7 +927,7 @@ def dense_variable(
926927 (indptr_len ,), dtype = idtype , name = name + "_indptr" , span = span
927928 )
928929 axis = DenseVariableAxis (name , length , indptr_buf )
929- self .context .func_buffer_map [ indptr_var ] = indptr_buf
930+ self .context .sp_struct2param_map [ axis ] = indptr_var
930931 self .context .update_symbol (var_name , axis , self .node )
931932 self .context .update_symbol (name + "_indptr" , indptr_buf , self .node )
932933
@@ -951,7 +952,7 @@ def sparse_fixed(
951952 (nnz ,), dtype = idtype , name = name + "_indices" , span = span
952953 )
953954 axis = SparseFixedAxis (name , length , indices_buf , nnz_cols )
954- self .context .func_buffer_map [ indices_var ] = indices_buf
955+ self .context .sp_struct2param_map [ axis ] = [ indices_var ]
955956 self .context .update_symbol (var_name , axis , self .node )
956957 self .context .update_symbol (name + "_indices" , indices_buf , self .node )
957958
@@ -980,8 +981,7 @@ def sparse_variable(
980981 (nnz ,), dtype = idtype , name = name + "_indices" , span = span
981982 )
982983 axis = SparseVariableAxis (name , length , indptr_buf , indices_buf )
983- self .context .func_buffer_map [indices_var ] = indices_buf
984- self .context .func_buffer_map [indptr_var ] = indptr_buf
984+ self .context .sp_struct2param_map [axis ] = [indptr_var , indices_var ]
985985 self .context .update_symbol (var_name , axis , self .node )
986986 self .context .update_symbol (name + "_indptr" , indptr_buf , self .node )
987987 self .context .update_symbol (name + "_indices" , indices_buf , self .node )
@@ -1017,8 +1017,7 @@ def match_sparse_buffer(
10171017 if param in self .context .func_params :
10181018 data = tvm .tir .decl_buffer (nnz , dtype , buffer_name + "_data" , span = span )
10191019 buffer = tvm .tir .sparse .SparseBuffer (axes , data , buffer_name )
1020- self .context .func_buffer_map [param ] = data
1021- self .context .func_sparse_buffer_map [param ] = buffer
1020+ self .context .sp_struct2param_map [buffer ] = [param ]
10221021 self .context .update_symbol (buffer_name + "_data" , data , self .node )
10231022 self .context .update_symbol (buffer_name , buffer , self .node )
10241023 else :
0 commit comments