@@ -333,21 +333,25 @@ def __init__(
333333 self .proj_out = operations .Linear (self .inner_dim , patch_size * patch_size * self .out_channels , bias = True , dtype = dtype , device = device )
334334 self .gradient_checkpointing = False
335335
336- def pos_embeds (self , x , context ):
336+ def process_img (self , x , index = 0 , h_offset = 0 , w_offset = 0 ):
337337 bs , c , t , h , w = x .shape
338338 patch_size = self .patch_size
339+ hidden_states = comfy .ldm .common_dit .pad_to_patch_size (x , (1 , self .patch_size , self .patch_size ))
340+ orig_shape = hidden_states .shape
341+ hidden_states = hidden_states .view (orig_shape [0 ], orig_shape [1 ], orig_shape [- 2 ] // 2 , 2 , orig_shape [- 1 ] // 2 , 2 )
342+ hidden_states = hidden_states .permute (0 , 2 , 4 , 1 , 3 , 5 )
343+ hidden_states = hidden_states .reshape (orig_shape [0 ], (orig_shape [- 2 ] // 2 ) * (orig_shape [- 1 ] // 2 ), orig_shape [1 ] * 4 )
339344 h_len = ((h + (patch_size // 2 )) // patch_size )
340345 w_len = ((w + (patch_size // 2 )) // patch_size )

342- img_ids = torch .zeros ((h_len , w_len , 3 ), device = x .device , dtype = x .dtype )
343- img_ids [:, :, 1 ] = img_ids [:, :, 1 ] + torch .linspace (0 , h_len - 1 , steps = h_len , device = x .device , dtype = x .dtype ).unsqueeze (1 )
344- img_ids [:, :, 2 ] = img_ids [:, :, 2 ] + torch .linspace (0 , w_len - 1 , steps = w_len , device = x .device , dtype = x .dtype ).unsqueeze (0 )
345- img_ids = repeat (img_ids , "h w c -> b (h w) c" , b = bs )
347+ h_offset = ((h_offset + (patch_size // 2 )) // patch_size )
348+ w_offset = ((w_offset + (patch_size // 2 )) // patch_size )
346349
347- txt_start = round (max (h_len , w_len ))
348- txt_ids = torch .linspace (txt_start , txt_start + context .shape [1 ], steps = context .shape [1 ], device = x .device , dtype = x .dtype ).reshape (1 , - 1 , 1 ).repeat (bs , 1 , 3 )
349- ids = torch .cat ((txt_ids , img_ids ), dim = 1 )
350- return self .pe_embedder (ids ).squeeze (1 ).unsqueeze (2 ).to (x .dtype )
350+ img_ids = torch .zeros ((h_len , w_len , 3 ), device = x .device , dtype = x .dtype )
351 img_ids [:, :, 0 ] = img_ids [:, :, 0 ] + index # channel 0 carries the image index; reading channel 1 here was a copy-paste slip (benign only because the tensor is freshly zeroed)
352 img_ids [:, :, 1 ] = img_ids [:, :, 1 ] + torch .linspace (h_offset , h_len - 1 + h_offset , steps = h_len , device = x .device , dtype = x .dtype ).unsqueeze (1 )
353 img_ids [:, :, 2 ] = img_ids [:, :, 2 ] + torch .linspace (w_offset , w_len - 1 + w_offset , steps = w_len , device = x .device , dtype = x .dtype ).unsqueeze (0 )
354 return hidden_states , repeat (img_ids , "h w c -> b (h w) c" , b = bs ), orig_shape
351355
352356 def forward (
353357 self ,
@@ -363,13 +367,13 @@ def forward(
363367 encoder_hidden_states = context
364368 encoder_hidden_states_mask = attention_mask
365369
366- image_rotary_emb = self .pos_embeds (x , context )
370+ hidden_states , img_ids , orig_shape = self .process_img (x )
371+ num_embeds = hidden_states .shape [1 ]
367372
368- hidden_states = comfy .ldm .common_dit .pad_to_patch_size (x , (1 , self .patch_size , self .patch_size ))
369- orig_shape = hidden_states .shape
370- hidden_states = hidden_states .view (orig_shape [0 ], orig_shape [1 ], orig_shape [- 2 ] // 2 , 2 , orig_shape [- 1 ] // 2 , 2 )
371- hidden_states = hidden_states .permute (0 , 2 , 4 , 1 , 3 , 5 )
372- hidden_states = hidden_states .reshape (orig_shape [0 ], (orig_shape [- 2 ] // 2 ) * (orig_shape [- 1 ] // 2 ), orig_shape [1 ] * 4 )
373+ txt_start = round (max (((x .shape [- 1 ] + (self .patch_size // 2 )) // self .patch_size ), ((x .shape [- 2 ] + (self .patch_size // 2 )) // self .patch_size )))
374+ txt_ids = torch .linspace (txt_start , txt_start + context .shape [1 ], steps = context .shape [1 ], device = x .device , dtype = x .dtype ).reshape (1 , - 1 , 1 ).repeat (x .shape [0 ], 1 , 3 )
375+ ids = torch .cat ((txt_ids , img_ids ), dim = 1 )
376+ image_rotary_emb = self .pe_embedder (ids ).squeeze (1 ).unsqueeze (2 ).to (x .dtype )
373377
374378 hidden_states = self .img_in (hidden_states )
375379 encoder_hidden_states = self .txt_norm (encoder_hidden_states )
@@ -408,6 +412,6 @@ def block_wrap(args):
408412 hidden_states = self .norm_out (hidden_states , temb )
409413 hidden_states = self .proj_out (hidden_states )
410414
411- hidden_states = hidden_states .view (orig_shape [0 ], orig_shape [- 2 ] // 2 , orig_shape [- 1 ] // 2 , orig_shape [1 ], 2 , 2 )
415+ hidden_states = hidden_states [:, : num_embeds ] .view (orig_shape [0 ], orig_shape [- 2 ] // 2 , orig_shape [- 1 ] // 2 , orig_shape [1 ], 2 , 2 )
412416 hidden_states = hidden_states .permute (0 , 3 , 1 , 4 , 2 , 5 )
413417 return hidden_states .reshape (orig_shape )[:, :, :, :x .shape [- 2 ], :x .shape [- 1 ]]
0 commit comments