 import re
 import logging
 import textwrap
+from tvm.tir.stmt_functor import post_order_visit
 
 PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY = """
     cudaError_t result_{0} = cudaFuncSetAttribute({0}, cudaFuncAttributeMaxDynamicSharedMemorySize, {1});
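For reference, the dispatcher later fills this template with a kernel symbol and a byte count via `str.format`. A minimal sketch, assuming a hypothetical kernel name and a 48 KiB dynamic shared memory request:

```python
# Hypothetical values; {0} is the kernel symbol, {1} the byte count.
snippet = PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY.format("main_kernel", 49152)
# -> cudaError_t result_main_kernel = cudaFuncSetAttribute(
#        main_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 49152);
```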
@@ -260,7 +261,11 @@ def create_dispatch_func(self, code, function_informations):
         # Format the function arguments for declaration
         def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])
 
-        def func_call_args(s, function_args, desc_name_map: Optional[Dict[str, str]] = None):
+        def func_call_args(s,
+                           function_args,
+                           function_params,
+                           desc_name_map: Optional[Dict[str, str]] = None,
+                           desc_name_var_map: Optional[Dict[str, tvm.tir.Var]] = None):
             # Extract the function call arguments matching the function definition
             def maybe_desc(name: str, matches: List[str], i: int):
                 match = matches[i]
@@ -280,8 +285,15 @@ def maybe_desc(name: str, matches: List[str], i: int):
             call_args = []
             for i, match in enumerate(matches):
                 for arg in function_args:
-                    if arg["name"] == match or maybe_desc(arg["name"], matches, i):
+                    if arg["name"] == match:
                         call_args.append(match)
+                    elif maybe_desc(arg["name"], matches, i):
+                        call_args.append(match)
+                        assert len(call_args) <= len(
+                            function_params
+                        ), f"Function {function_name}: {len(function_params)} params vs {len(call_args)} call args"
+                        desc_name_var_map[match] = function_params[len(call_args) - 1]
+
             return call_args
 
         has_l2_persistent_map = False
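The `desc_name_var_map` bookkeeping added above relies on positional alignment: right after an argument is appended, `function_params[len(call_args) - 1]` is the TIR parameter at the same index. A standalone sketch with stand-in values (the real code holds `tvm.tir.Var` objects recovered from the packed call, and `maybe_desc` does regex matching rather than a suffix check):

```python
# Stand-ins for tvm.tir.Var parameters and the names parsed from the declaration.
function_params = ["p_A", "p_A_desc", "p_B"]
matches = ["A", "A_desc", "B"]

desc_name_var_map = {}
call_args = []
for i, match in enumerate(matches):
    call_args.append(match)
    if match.endswith("_desc"):  # crude stand-in for maybe_desc()
        # The just-appended call arg and its TIR parameter share an index.
        desc_name_var_map[match] = function_params[len(call_args) - 1]

assert desc_name_var_map == {"A_desc": "p_A_desc"}
```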
@@ -294,10 +306,12 @@ def maybe_desc(name: str, matches: List[str], i: int):
         if has_l2_persistent_map:
             kernel_launch_code += L2_PERSISTENT_MAP_CREATE_HANDLE
         desc_name_map: Dict[str, str] = {}
+        desc_name_var_map: Dict[str, tvm.tir.Var] = {}
         for function_name, function_info in function_informations.items():
             block_info = function_info["block_info"]
             grid_info = function_info["grid_info"]
             dynamic_smem_buf = function_info["dynamic_smem_buf"]
+            function_params = function_info["function_params"]
 
             # Find the location of the global kernel function in the code
             index = match_declare_kernel(code, function_name + "(")
@@ -321,22 +335,32 @@ def maybe_desc(name: str, matches: List[str], i: int):
             kernel_launch_code += init_l2_persistent_map
 
             if self.use_cooperative_groups[function_name]:
-                args_list = func_call_args(declaration, function_args, desc_name_map)
+                args_list = func_call_args(declaration, function_args, function_params,
+                                           desc_name_map, desc_name_var_map)
+                assert len(function_params) == len(
+                    args_list
+                ), f"Function {function_name}: {len(function_params)} params vs {len(args_list)} call args"
                 args_array = [f"(void*)&{arg}" for arg in args_list]
                 call_args = f"\tvoid* {function_name}_args[] = {{{', '.join(args_array)}}};\n"
                 kernel_launch_code += call_args
                 # Using cudaLaunchCooperativeKernel to launch the kernel
                 kernel_launch_code += "\tTILELANG_CHECK(cudaLaunchCooperativeKernel((void*){}, {}, {}, {}, {}, stream));\n".format(
                     function_name, grid_str, block_str, function_name + "_args", smem_str)
             else:
-                call_args = ", ".join(func_call_args(declaration, function_args, desc_name_map))
+                args_list = func_call_args(declaration, function_args, function_params,
+                                           desc_name_map, desc_name_var_map)
+                assert len(function_params) == len(
+                    args_list
+                ), f"Function {function_name}: {len(function_params)} params vs {len(args_list)} call args"
+                call_args = ", ".join(args_list)
                 kernel_launch_code += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(
                     function_name, grid_str, block_str, smem_str, call_args)
                 kernel_launch_code += "\tTILELANG_CHECK_LAST_ERROR(\"{}\");\n".format(function_name)
             if has_l2_persistent_map:
                 kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE
 
-        init_tma_descriptor_args = self.generate_tma_descriptor_args(desc_name_map)
+        init_tma_descriptor_args = self.generate_tma_descriptor_args(desc_name_map,
+                                                                     desc_name_var_map)
         kernel_launch_code = init_tma_descriptor_args + kernel_launch_code
 
         # Wrap the kernel dispatch logic in an external C function
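Concretely, the non-cooperative branch emits plain triple-chevron launch syntax. A sketch of what the two format calls produce, with hypothetical launch configuration strings:

```python
# Hypothetical kernel name and launch configuration.
function_name, grid_str, block_str, smem_str = (
    "main_kernel", "dim3(128, 1, 1)", "dim3(256, 1, 1)", "0")
call_args = "A, B, C"

kernel_launch_code = "\t{}<<<{}, {}, {}, stream>>>({});\n".format(
    function_name, grid_str, block_str, smem_str, call_args)
kernel_launch_code += "\tTILELANG_CHECK_LAST_ERROR(\"{}\");\n".format(function_name)
# ->  main_kernel<<<dim3(128, 1, 1), dim3(256, 1, 1), 0, stream>>>(A, B, C);
#     TILELANG_CHECK_LAST_ERROR("main_kernel");
```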
@@ -362,15 +386,17 @@ def generate_l2_persistent_map(self, function_name: str) -> str:
 
         return init_l2_persistent_map
 
-    def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str:
+    def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str],
+                                     desc_name_var_map: Dict[str, tvm.tir.Var]) -> str:
         tma_descripter_init = ""
         if self.tma_descriptor_args is None:
             return tma_descripter_init
+        for handle_name, _ in desc_name_map.items():
+            assert handle_name in desc_name_var_map, f"Handle name {handle_name} not found in desc_name_var_map"
+            desc_var = desc_name_var_map[handle_name]
 
-        for handle_name, name in desc_name_map.items():
-            desc_name = name + "_desc"
-            assert desc_name in self.tma_descriptor_args, f"TMA descriptor {desc_name} not found in {self.tma_descriptor_args}"
-            args = self.tma_descriptor_args[desc_name]
+            assert desc_var in self.tma_descriptor_args, f"TMA descriptor {desc_var} not found in {self.tma_descriptor_args}"
+            args = self.tma_descriptor_args[desc_var]
             # Skip __tvm_tensormap_create_tiled
             if len(args) < 3:
                 raise ValueError(
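The switch from string keys (`name + "_desc"`) to `tvm.tir.Var` keys makes the lookup identity-based: the Var recovered from the packed call is the very object that keyed `tma_descriptor_args`, so two descriptors that happen to share a name can no longer collide. A small sketch, assuming TVM is installed and relying on `tvm.tir.Var` hashing by object identity:

```python
import tvm

desc_var = tvm.tir.Var("A_desc", "handle")
tma_descriptor_args = {desc_var: ["__tvm_tensormap_create_tiled", desc_var]}

# The original Var object is found; a fresh Var with the same name is not.
assert desc_var in tma_descriptor_args
assert tvm.tir.Var("A_desc", "handle") not in tma_descriptor_args
```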
@@ -536,12 +562,35 @@ def update_lib_code(self, code: str):
             # Do not update function with dispatch host function
             if (function_name not in self.block_info) or (function_name not in self.grid_info):
                 continue
+            assert function_name in self.device_mod, f"Function {function_name} not found in device module"
+            device_func = self.device_mod[function_name]
+            kernel_params_cnt = len(device_func.params)
+            function_params: List[str] = None
+
+            def visitor(node, fn=function_name, param_cnt=kernel_params_cnt):
+                nonlocal function_params
+                if isinstance(node, tvm.tir.Call):
+                    if not (hasattr(node, "op") and
+                            node.op == tvm.ir.Op.get("tir.tvm_call_packed")):
+                        return
+                    args = node.args
+                    if not args or args[0] != fn:
+                        return
+                    if len(args) < 1 + param_cnt:
+                        raise AssertionError(
+                            "tvm_call_packed should have at least 1 argument and match device function parameters"
+                        )
+                    function_params = args[1:1 + param_cnt]
+
+            post_order_visit(self.host_func.body, visitor)
+            assert function_params is not None, "function_params should not be None"
 
             function_informations[function_name] = {
                 "function_name": function_name,
                 "block_info": self.block_info[function_name],
                 "grid_info": self.grid_info[function_name],
                 "dynamic_smem_buf": self.dynamic_smem_buf[function_name],
+                "function_params": function_params,
             }
 
         # Create the host function wrapper for the CUDA kernel
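The visitor above hinges on `tvm.tir.stmt_functor.post_order_visit`, which applies a callback to every node of a TIR statement exactly once in post-DFS order. A self-contained sketch of pulling the parameters out of a `tvm_call_packed` node, mirroring what the patch does to the lowered host body (hypothetical kernel name):

```python
import tvm
from tvm.tir.stmt_functor import post_order_visit

x = tvm.tir.Var("x", "handle")
# Lowered host code invokes the kernel through tvm_call_packed; args[0] is
# the kernel symbol and the remaining args are the kernel parameters.
stmt = tvm.tir.Evaluate(tvm.tir.call_packed("main_kernel", x))

found = None

def visitor(node):
    global found
    if (isinstance(node, tvm.tir.Call) and
            node.op == tvm.ir.Op.get("tir.tvm_call_packed") and
            node.args and node.args[0] == "main_kernel"):
        found = node.args[1:]  # the kernel parameters

post_order_visit(stmt, visitor)
print(found)  # [x]
```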
@@ -579,6 +628,19 @@ def device_func(self):
                     return function
             raise ValueError("Cannot find primary function in the module.")
 
+    @property
+    def host_func(self):
+        if len(self.host_mod.get_global_vars()) == 1:
+            return self.host_mod[self.host_mod.get_global_vars()[0]]
+        elif "main" in self.host_mod:
+            return self.host_mod["main"]
+        else:
+            for _, function in self.host_mod.functions.items():
+                attr = function.attrs
+                if "tir.is_global_func" in attr and attr["tir.is_global_func"]:
+                    return function
+            raise ValueError("Cannot find primary function in the module.")
+
 
 class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
     """
@@ -636,7 +698,6 @@ def create_dispatch_func(self, code, function_informations):
                 function_args.append({"name": dyn_sym, "type": "ctypes.c_int"})
 
         function_args.append(self.get_stream_type())
-
         # Format the function arguments for declaration
         def_args = ", ".join([f"{arg['name']}" for arg in function_args])
 
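For contrast with the CUDA path, the NVRTC wrapper declares bare argument names for a ctypes entry point rather than C-typed parameters. A tiny sketch with hypothetical arguments:

```python
# Hypothetical function_args; only the names survive into def_args here.
function_args = [
    {"name": "A", "type": "ctypes.c_void_p"},
    {"name": "n", "type": "ctypes.c_int"},
    {"name": "stream", "type": "ctypes.c_void_p"},
]
def_args = ", ".join([f"{arg['name']}" for arg in function_args])
assert def_args == "A, n, stream"
```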