@@ -27,36 +27,34 @@ def num_tokens(self) -> int:
2727UBatchSlices : TypeAlias = list [UBatchSlice ]
2828
2929
30- def is_second_ubatch_empty (orig_num_tokens : int , padded_num_tokens : int ) -> bool :
31- return (padded_num_tokens // 2 ) >= orig_num_tokens
def is_last_ubatch_empty(
    orig_num_tokens: int, padded_num_tokens: int, num_ubatches: int
) -> bool:
    """Return True when the last micro-batch would hold no real tokens.

    The padded token count is floor-divided by ``num_ubatches`` to get the
    per-micro-batch size. If the first ``num_ubatches - 1`` micro-batches of
    that size already cover ``orig_num_tokens``, nothing but padding is left
    for the last one.

    Args:
        orig_num_tokens: Number of tokens before padding.
        padded_num_tokens: Number of tokens after padding.
        num_ubatches: Number of micro-batches the padded tokens are split
            into. Must be at least 1.

    Returns:
        True if the tokens covered by all but the last micro-batch are
        greater than or equal to ``orig_num_tokens``.

    Raises:
        ValueError: If ``num_ubatches`` is less than 1.
    """
    # Fail with a descriptive error instead of a bare ZeroDivisionError (for
    # 0) or a silently meaningless answer (for negative counts).
    if num_ubatches < 1:
        raise ValueError(f"num_ubatches must be >= 1, got {num_ubatches}")

    tokens_per_ubatch = padded_num_tokens // num_ubatches
    return tokens_per_ubatch * (num_ubatches - 1) >= orig_num_tokens
3234
3335
def check_ubatch_thresholds(
    config: ParallelConfig, num_tokens: int, uniform_decode: bool
) -> bool:
    """Report whether ``num_tokens`` reaches the configured token threshold.

    Always returns False when ``config.use_ubatching`` is falsy. Otherwise
    ``num_tokens`` is compared against ``config.dbo_decode_token_threshold``
    when ``uniform_decode`` is truthy, and against
    ``config.dbo_prefill_token_threshold`` when it is not.
    """
    if not config.use_ubatching:
        return False

    # Only the threshold matching the batch type is read from the config.
    threshold = (
        config.dbo_decode_token_threshold
        if uniform_decode
        else config.dbo_prefill_token_threshold
    )
    return num_tokens >= threshold
4345
4446
45- # This just pads the second ubatch slice out to the total number of tokens
47+ # This pads the last ubatch slice out to the total number of tokens
4648# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
def _pad_out_ubatch_slices(
    ubatch_slices: UBatchSlices, num_total_tokens: int, num_reqs_padded: int
) -> UBatchSlices:
    """Return the slices with the last one stretched to the padded totals.

    Every entry except the last is carried over unchanged. The last entry is
    replaced by a new ``UBatchSlice`` that keeps the original start of its
    request and token slices but stops at ``num_reqs_padded`` and
    ``num_total_tokens`` respectively. The input list is not modified.
    """
    tail = ubatch_slices[-1]
    stretched_tail = UBatchSlice(
        slice(tail.request_slice.start, num_reqs_padded),
        slice(tail.token_slice.start, num_total_tokens),
    )
    return ubatch_slices[:-1] + [stretched_tail]
6159
6260
@@ -65,40 +63,45 @@ def maybe_create_ubatch_slices(
6563 num_scheduled_tokens : np .ndarray ,
6664 num_tokens_padded : int ,
6765 num_reqs_padded : int ,
68- split_point : int | None = None ,
66+ num_ubatches : int ,
67+ split_point : list [int ] | int | None = None ,
6968) -> tuple [UBatchSlices | None , UBatchSlices | None ]:
7069 if not should_ubatch :
7170 return None , None
7271
7372 if split_point is None :
74- split_point = int (num_tokens_padded ) // 2
73+ split_point = int (num_tokens_padded ) // num_ubatches
74+
75+ token_split_points = [split_point * i for i in range (1 , num_ubatches )]
7576
7677 # TODO(lucas): Refactor the gpu_model_runner.py so we can pass
7778 # in cu_num_tokens directly (i.e. query_start_loc)
7879 cu_num_tokens = np .zeros (len (num_scheduled_tokens ) + 1 , dtype = np .int32 )
7980 np .cumsum (num_scheduled_tokens , dtype = np .int32 , out = cu_num_tokens [1 :])
8081
81- first_ubatch_token_slice = slice ( 0 , split_point )
82- second_ubatch_token_slice = slice ( split_point , cu_num_tokens [ - 1 ])
82+ ubatch_slices = []
83+ start_token = 0
8384
84- # Determine request slices using exclusive stop semantics
85- # First ubatch includes requests whose tokens overlap [0, split_point)
86- first_ubatch_req_stop = int (
87- np .searchsorted (cu_num_tokens , split_point , side = "left" )
88- )
89- first_ubatch_req_slice = slice (0 , first_ubatch_req_stop )
85+ # Add the end point to the split points to make iteration easier
86+ all_points = token_split_points + [cu_num_tokens [- 1 ]]
9087
91- # Second ubatch starts at the request that contains the split_point
92- # or the request starting exactly at split_point (if on boundary)
93- second_ubatch_req_start = int (
94- np .searchsorted (cu_num_tokens , split_point , side = "right" ) - 1
95- )
96- second_ubatch_req_slice = slice (second_ubatch_req_start , len (cu_num_tokens ) - 1 )
88+ for end_token in all_points :
89+ token_slice = slice (start_token , end_token )
9790
98- ubatch_slices = [
99- UBatchSlice (first_ubatch_req_slice , first_ubatch_token_slice ),
100- UBatchSlice (second_ubatch_req_slice , second_ubatch_token_slice ),
101- ]
91+ # Determine request slices using exclusive stop semantics
92+ # Ubatch includes requests whose tokens overlap [start_token, end_token)
93+
94+ # Start at the request that contains the start_token
95+ # or the request starting exactly at start_token (if on boundary)
96+ req_start = int (np .searchsorted (cu_num_tokens , start_token , side = "right" ) - 1 )
97+
98+ # Stop at the request that starts at or after end_token
99+ req_stop = int (np .searchsorted (cu_num_tokens , end_token , side = "left" ))
100+
101+ req_slice = slice (req_start , req_stop )
102+ ubatch_slices .append (UBatchSlice (req_slice , token_slice ))
103+
104+ start_token = end_token
102105
103106 ubatch_slices_padded = _pad_out_ubatch_slices (
104107 ubatch_slices , num_tokens_padded , num_reqs_padded
0 commit comments