@@ -1482,14 +1482,9 @@ def get_rope_index(
             if attention_mask is None:
                 attention_mask = torch.ones_like(total_input_ids)
             position_ids = torch.ones(
-                3,
-                input_ids.shape[0],
-                input_ids.shape[1],
-                dtype=input_ids.dtype,
-                device=input_ids.device,
+                3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
             )
             image_index, video_index = 0, 0
-            attention_mask = attention_mask.to(total_input_ids.device)
             for i, input_ids in enumerate(total_input_ids):
                 input_ids = input_ids[attention_mask[i] == 1]
                 image_nums, video_nums = 0, 0
@@ -1516,21 +1511,15 @@ def get_rope_index(
                             image_grid_thw[image_index][1],
                             image_grid_thw[image_index][2],
                         )
-                        second_per_grid_t = 0
                         image_index += 1
                         remain_images -= 1
                         ed = ed_image
-
                     else:
                         t, h, w = (
                             video_grid_thw[video_index][0],
                             video_grid_thw[video_index][1],
                             video_grid_thw[video_index][2],
                         )
-                        if second_per_grid_ts is not None:
-                            second_per_grid_t = second_per_grid_ts[video_index]
-                        else:
-                            second_per_grid_t = 1.0
                         video_index += 1
                         remain_videos -= 1
                         ed = ed_video
@@ -1544,15 +1533,7 @@ def get_rope_index(
                     st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                     llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

-                    t_index = (
-                        (
-                            torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
-                            * second_per_grid_t
-                            * self.config.vision_config.tokens_per_second
-                        )
-                        .long()
-                        .flatten()
-                    )
+                    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                     h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                     w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                     llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
@@ -1572,7 +1553,7 @@ def get_rope_index(
             if attention_mask is not None:
                 position_ids = attention_mask.long().cumsum(-1) - 1
                 position_ids.masked_fill_(attention_mask == 0, 1)
-                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
+                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
                 max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
                 mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
             else:
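
For context on the third hunk: the "+" side reduces t_index to plain grid positions, while the removed "-" side scaled it by second_per_grid_t * self.config.vision_config.tokens_per_second so that video frames map onto real time. A minimal sketch of what the retained code computes; the grid sizes below are illustrative, not from the source:

import torch

# Example vision block: 2 temporal steps, 3 rows, 4 columns of patches.
llm_grid_t, llm_grid_h, llm_grid_w = 2, 3, 4

# Each patch gets a (t, h, w) position triple; view/expand/flatten enumerate
# the grid in t-major, then h, then w order, as in the diff's "+" lines.
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()

pos = torch.stack([t_index, h_index, w_index])  # shape (3, 2 * 3 * 4)
print(pos[:, :5])
# tensor([[0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 1],
#         [0, 1, 2, 3, 0]])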
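
The fourth hunk's text-only branch can likewise be checked in isolation; the attention mask below is made up. Positions are a cumulative count of unmasked tokens (pad slots filled with 1), broadcast to the 3 M-RoPE axes, and mrope_position_deltas records how far each row's last position falls short of the padded length:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

position_ids = attention_mask.long().cumsum(-1) - 1      # [[0, 1, 2, 2, 2], [0, 1, 2, 3, 4]]
position_ids.masked_fill_(attention_mask == 0, 1)        # pad slots -> 1
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)  # (3, batch, seq_len)

max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
print(mrope_position_deltas)  # tensor([[-2], [0]])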