@@ -857,6 +857,7 @@ def __init__(self):
857857 def dense_fixed (name : str , length : PrimExpr , span : Optional [Span ] = None ):
858858 var_name = self .node .lhs [0 ].id .name
859859 axis = DenseFixedAxis (name , length )
860+ self .context .sp_struct2param_map [axis ] = []
860861 self .context .update_symbol (var_name , axis , self .node )
861862
862863 super ().__init__ (dense_fixed , def_symbol = True )
@@ -880,7 +881,7 @@ def dense_variable(
880881 (indptr_len ,), dtype = idtype , name = name + "_indptr" , span = span
881882 )
882883 axis = DenseVariableAxis (name , length , indptr_buf )
883- self .context .func_buffer_map [ indptr_var ] = indptr_buf
884+ self .context .sp_struct2param_map [ axis ] = indptr_var
884885 self .context .update_symbol (var_name , axis , self .node )
885886 self .context .update_symbol (name + "_indptr" , indptr_buf , self .node )
886887
@@ -905,7 +906,7 @@ def sparse_fixed(
905906 (nnz ,), dtype = idtype , name = name + "_indices" , span = span
906907 )
907908 axis = SparseFixedAxis (name , length , indices_buf , nnz_cols )
908- self .context .func_buffer_map [ indices_var ] = indices_buf
909+ self .context .sp_struct2param_map [ axis ] = [ indices_var ]
909910 self .context .update_symbol (var_name , axis , self .node )
910911 self .context .update_symbol (name + "_indices" , indices_buf , self .node )
911912
@@ -934,8 +935,7 @@ def sparse_variable(
934935 (nnz ,), dtype = idtype , name = name + "_indices" , span = span
935936 )
936937 axis = SparseVariableAxis (name , length , indptr_buf , indices_buf )
937- self .context .func_buffer_map [indices_var ] = indices_buf
938- self .context .func_buffer_map [indptr_var ] = indptr_buf
938+ self .context .sp_struct2param_map [axis ] = [indptr_var , indices_var ]
939939 self .context .update_symbol (var_name , axis , self .node )
940940 self .context .update_symbol (name + "_indptr" , indptr_buf , self .node )
941941 self .context .update_symbol (name + "_indices" , indices_buf , self .node )
@@ -971,8 +971,7 @@ def match_sparse_buffer(
971971 if param in self .context .func_params :
972972 data = tvm .tir .decl_buffer (nnz , dtype , buffer_name + "_data" , span = span )
973973 buffer = tvm .tir .sparse .SparseBuffer (axes , data , buffer_name )
974- self .context .func_buffer_map [param ] = data
975- self .context .func_sparse_buffer_map [param ] = buffer
974+ self .context .sp_struct2param_map [buffer ] = [param ]
976975 self .context .update_symbol (buffer_name + "_data" , data , self .node )
977976 self .context .update_symbol (buffer_name , buffer , self .node )
978977 else :
0 commit comments