 from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.v1.worker.ubatch_utils import UBatchSlices, is_second_ubatch_empty
 
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
@@ -97,6 +98,53 @@ def num_tokens_across_dp(num_tokens: int, dp_size: int,
         dist.all_reduce(num_tokens_tensor, group=group)
         return num_tokens_tensor.cpu()
 
+    @staticmethod
+    def should_ubatch_across_dp(
+            should_ubatch: bool, orig_num_tokens_per_ubatch: int,
+            padded_num_tokens_per_ubatch: int, dp_size: int,
+            dp_rank: int) -> tuple[bool, Optional[torch.Tensor]]:
+        """
+        1. Decides whether each DP rank is going to microbatch. Either all
+        ranks run with microbatching or none of them do. If this function
+        decides not to run with microbatching, it "aborts": no padding
+        information is returned to the caller and the result is (False, None).
+
+        2. Determines the total number of tokens that each rank will run.
+        All ranks will be padded out so that they run with the same number
+        of tokens.
+
+        Returns: tuple[
+            should_ubatch: Are all DP ranks going to microbatch
+            num_tokens_after_padding: A tensor containing the total number of
+            tokens per-microbatch for each DP rank including padding. Will be
+            None if should_ubatch is False
+        ]
+        """
+
+        device = current_platform.device_type
+        tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32)
+        tensor[0][dp_rank] = orig_num_tokens_per_ubatch
+        tensor[1][dp_rank] = padded_num_tokens_per_ubatch
+        tensor[2][dp_rank] = 1 if should_ubatch else 0
+
+        from vllm.distributed.parallel_state import get_dp_group
+        dist.all_reduce(tensor, group=get_dp_group().device_group)
+
+        result: bool = bool(torch.all(tensor[2] == 1).item())
+        if not result:
+            return result, None
+
+        orig_num_tokens_tensor = tensor[0, :]
+        padded_num_tokens_tensor = tensor[1, :]
+
+        orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
+        padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
+        if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens):
+            logger.debug("Aborting ubatching %s %s", orig_min_num_tokens,
+                         padded_max_num_tokens)
+            return False, None
+        return result, padded_num_tokens_tensor.cpu()
+
     @staticmethod
     def make(
             parallel_config: ParallelConfig,
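For reviewers' reference, here is a minimal CPU-only sketch of the consensus logic in `should_ubatch_across_dp`: each rank fills one column of a `(3, dp_size)` tensor, and summing the per-rank tensors stands in for `dist.all_reduce`, so every rank sees every other rank's original token count, padded token count, and microbatch vote. The per-rank numbers are invented, and the `is_second_ubatch_empty` abort is only noted in a comment.

```python
import torch

dp_size = 4
# (orig_tokens, padded_tokens, should_ubatch) reported by each hypothetical rank.
per_rank = [(512, 512, 1), (384, 512, 1), (768, 768, 1), (96, 768, 1)]

reduced = torch.zeros(3, dp_size, dtype=torch.int32)
for rank, (orig, padded, vote) in enumerate(per_rank):
    local = torch.zeros(3, dp_size, dtype=torch.int32)
    local[0][rank] = orig
    local[1][rank] = padded
    local[2][rank] = vote
    reduced += local  # stands in for dist.all_reduce(local, group=...)

# All ranks must vote yes, otherwise the real method returns (False, None).
should_ubatch = bool(torch.all(reduced[2] == 1).item())
orig_min = int(reduced[0].min().item())
padded_max = int(reduced[1].max().item())
# The real method additionally aborts when
# is_second_ubatch_empty(orig_min, padded_max) is true.
print(should_ubatch, orig_min, padded_max)  # True 96 768
```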
@@ -119,14 +167,15 @@ def make(
 
         # If num_tokens_across_dp is None, it will be computed by all_reduce
         # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
-        assert (num_tokens_across_dp is None
-                or num_tokens_across_dp[dp_rank] == batchsize)
+        assert (num_tokens_across_dp is None or num_tokens_across_dp[dp_rank]
+                == batchsize), f"{num_tokens_across_dp[dp_rank]} {batchsize}"
         if num_tokens_across_dp is None:
             num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
                 batchsize, dp_size, dp_rank)
         max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
         cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
-        return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu)
+        return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu,
+                          num_tokens_across_dp)
 
     @contextmanager
     def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
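For context, a tiny sketch (with invented per-rank counts) of what `DPMetadata.make` derives once `num_tokens_across_dp` is known: the max determines how far every rank pads, and the cumulative sum gives each rank its token offset.

```python
import torch

# Hypothetical per-rank token counts for dp_size == 4.
num_tokens_across_dp = torch.tensor([512, 384, 768, 96])

max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)

print(max_tokens_across_dp_cpu)  # tensor(768)
print(cu_tokens_across_dp_cpu)   # tensor([ 512,  896, 1664, 1760])
```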
@@ -179,9 +228,12 @@ class ForwardContext:
     Type AttentionMetadata for v0,
     Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
     attention layer to its attention metadata
-    set dynamically for each forward pass
+    Type List[Dict[str, AttentionMetadata]] for DBO. List of size two, one
+    for each microbatch.
+    Set dynamically for each forward pass
     """
-    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]]
+    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"],
+                         list[dict[str, "AttentionMetadata"]]]
     # TODO: remove after making all virtual_engines share the same kv cache
     virtual_engine: int  # set dynamically for each forward pass
     # set dynamically for each forward pass
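Purely illustrative: the three shapes `attn_metadata` can now take. `_FakeAttnMeta` and the layer key are stand-ins for a backend-specific `AttentionMetadata` object and a real layer name.

```python
from dataclasses import dataclass


@dataclass
class _FakeAttnMeta:
    num_actual_tokens: int


v0_style = _FakeAttnMeta(8)                                     # single object
v1_style = {"model.layers.0.self_attn.attn": _FakeAttnMeta(8)}  # per-layer dict
dbo_style = [                                                   # one dict per microbatch
    {"model.layers.0.self_attn.attn": _FakeAttnMeta(4)},
    {"model.layers.0.self_attn.attn": _FakeAttnMeta(4)},
]
```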
@@ -191,6 +243,8 @@ class ForwardContext:
     cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
     batch_descriptor: Optional[BatchDescriptor] = None
 
+    ubatch_slices: Optional[UBatchSlices] = None
+
     def __post_init__(self):
         assert self.cudagraph_runtime_mode in [
             CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
@@ -208,6 +262,39 @@ def get_forward_context() -> ForwardContext:
     return _forward_context
 
 
+def create_forward_context(
+        attn_metadata: Any,
+        vllm_config: VllmConfig,
+        virtual_engine: int = 0,
+        dp_metadata: Optional[DPMetadata] = None,
+        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+        batch_descriptor: Optional[BatchDescriptor] = None,
+        ubatch_slices: Optional[UBatchSlices] = None):
+    return ForwardContext(no_compile_layers=vllm_config.compilation_config.
+                          static_forward_context,
+                          virtual_engine=virtual_engine,
+                          attn_metadata=attn_metadata,
+                          dp_metadata=dp_metadata,
+                          cudagraph_runtime_mode=cudagraph_runtime_mode,
+                          batch_descriptor=batch_descriptor,
+                          ubatch_slices=ubatch_slices)
+
+
+@contextmanager
+def override_forward_context(forward_context: Optional[ForwardContext]):
+    """A context manager that overrides the current forward context.
+    This is used to override the forward context for a specific
+    forward pass.
+    """
+    global _forward_context
+    prev_context = _forward_context
+    _forward_context = forward_context
+    try:
+        yield
+    finally:
+        _forward_context = prev_context
+
+
 @contextmanager
 def set_forward_context(
         attn_metadata: Any,
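A hedged usage sketch of the two new helpers (assuming they are exported from `vllm.forward_context` alongside `set_forward_context`; `run_one_microbatch` and its arguments are placeholders, not part of this PR): build a per-microbatch `ForwardContext` with `create_forward_context` and install it only for that microbatch's forward pass with `override_forward_context`. The previous context is restored on exit even if the forward raises.

```python
from typing import Any, Optional

from vllm.config import VllmConfig
from vllm.forward_context import (DPMetadata, create_forward_context,
                                  override_forward_context)


def run_one_microbatch(model, ubatch_inputs: dict[str, Any],
                       ubatch_attn_metadata: dict[str, Any],
                       vllm_config: VllmConfig,
                       dp_metadata: Optional[DPMetadata],
                       ubatch_slices):
    """Hypothetical helper illustrating the intended call pattern."""
    ctx = create_forward_context(ubatch_attn_metadata,
                                 vllm_config,
                                 dp_metadata=dp_metadata,
                                 ubatch_slices=ubatch_slices)
    with override_forward_context(ctx):
        # Layers that call get_forward_context() now see this microbatch's
        # metadata; the prior context comes back when the block exits.
        return model(**ubatch_inputs)
```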
@@ -216,7 +303,8 @@ def set_forward_context(
         num_tokens: Optional[int] = None,
         num_tokens_across_dp: Optional[torch.Tensor] = None,
         cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-        batch_descriptor: Optional[BatchDescriptor] = None):
+        batch_descriptor: Optional[BatchDescriptor] = None,
+        ubatch_slices: Optional[UBatchSlices] = None):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
@@ -225,27 +313,22 @@ def set_forward_context(
     need_to_track_batchsize = track_batchsize and attn_metadata is not None
     if need_to_track_batchsize:
         forward_start_time = time.perf_counter()
+
     dp_metadata: Optional[DPMetadata] = None
     if vllm_config.parallel_config.data_parallel_size > 1 and (
             attn_metadata is not None or num_tokens is not None):
         dp_metadata = DPMetadata.make(vllm_config.parallel_config,
                                       attn_metadata, num_tokens or 0,
                                       num_tokens_across_dp)
 
-    global _forward_context
-    prev_context = _forward_context
-    _forward_context = ForwardContext(
-        no_compile_layers=vllm_config.compilation_config.
-        static_forward_context,
-        virtual_engine=virtual_engine,
-        attn_metadata=attn_metadata,
-        dp_metadata=dp_metadata,
-        cudagraph_runtime_mode=cudagraph_runtime_mode,
-        batch_descriptor=batch_descriptor,
-    )
+    forward_context = create_forward_context(attn_metadata, vllm_config,
+                                             virtual_engine, dp_metadata,
+                                             cudagraph_runtime_mode,
+                                             batch_descriptor, ubatch_slices)
 
     try:
-        yield
+        with override_forward_context(forward_context):
+            yield
     finally:
         global last_logging_time, batchsize_logging_interval
         if need_to_track_batchsize:
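For callers, the only visible change to `set_forward_context` is the extra optional `ubatch_slices` argument; existing call sites keep working unchanged. A rough sketch, with placeholder names:

```python
from typing import Any

import torch

from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context


def forward_pass(model, input_ids: torch.Tensor, positions: torch.Tensor,
                 attn_metadata: Any, vllm_config: VllmConfig,
                 ubatch_slices=None):
    """Hypothetical caller, not part of this PR."""
    with set_forward_context(attn_metadata,
                             vllm_config,
                             num_tokens=input_ids.shape[0],
                             ubatch_slices=ubatch_slices):
        return model(input_ids, positions)
```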
@@ -282,5 +365,3 @@ def set_forward_context(
                     logger.info(("Batchsize forward time stats "
                                  "(batchsize, count, median_time(ms)): %s"),
                                 forward_stats)
-
-        _forward_context = prev_context