@@ -2828,7 +2828,7 @@ def _get_mm_dummy_batch(
     def _dummy_run(
         self,
         num_tokens: int,
-        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+        cudagraph_runtime_mode: Optional[CUDAGraphMode] = None,
         force_attention: bool = False,
         uniform_decode: bool = False,
         allow_microbatching: bool = True,
@@ -2844,6 +2844,8 @@ def _dummy_run(
         Args:
             num_tokens: Number of tokens to run the dummy forward pass.
             cudagraph_runtime_mode: used to control the behavior.
+                - If not set, the cudagraph mode is determined via
+                    self.cudagraph_dispatcher.
                 - CUDAGraphMode.NONE: No cudagraph, for warm up and profile run
                 - CUDAGraphMode.PIECEWISE: Piecewise cudagraph.
                 - CUDAGraphMode.FULL: Full cudagraph, attention metadata is
@@ -2857,7 +2859,7 @@ def _dummy_run(
                 (1 token) and prefill (multiple tokens) requests.
             remove_lora: If False, dummy LoRAs are not destroyed after the run
         """
-        assert cudagraph_runtime_mode in {
+        assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in {
             CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
         }

@@ -2899,10 +2901,6 @@ def _dummy_run(
         elif uniform_decode:
             assert not create_mixed_batch
             num_reqs = cdiv(num_tokens, max_query_len)
-            assert num_reqs <= max_num_reqs, \
-                f"Do not capture num_reqs {num_reqs} > max_num_reqs " \
-                f"{max_num_reqs} for uniform batch. Num tokens: " \
-                f"{num_tokens}, max_query_len: {max_query_len}"
             num_scheduled_tokens_list = [max_query_len] * num_reqs
             if num_tokens % max_query_len != 0:
                 num_scheduled_tokens_list[-1] = num_tokens % max_query_len
@@ -3043,18 +3041,20 @@ def _dummy_run(

                 intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                     num_tokens, None, False)
-            if cudagraph_runtime_mode == CUDAGraphMode.NONE:
-                batch_descriptor = None
-            else:
-                # filter out the valid batch descriptor
-                _cg_mode, batch_descriptor = \
-                    self.cudagraph_dispatcher.dispatch(
-                        BatchDescriptor(num_tokens=num_tokens,
-                                        uniform_decode=uniform_decode))
-                # sanity check
-                assert cudagraph_runtime_mode == _cg_mode, (
+
+            # filter out the valid batch descriptor
+            _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
+                BatchDescriptor(num_tokens=num_tokens,
+                                uniform_decode=uniform_decode))
+            if cudagraph_runtime_mode is not None:
+                # We allow forcing NONE even when the dispatcher disagrees,
+                # to support warm-up runs before cudagraph capture.
+                assert cudagraph_runtime_mode == CUDAGraphMode.NONE or \
+                    cudagraph_runtime_mode == _cg_mode, (
                     f"Cudagraph runtime mode mismatch at dummy_run. "
                     f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.")
+            else:
+                cudagraph_runtime_mode = _cg_mode

             if ubatch_slices is not None:
                 num_tokens = num_tokens // 2