@@ -210,7 +210,7 @@ def block_wrap(args):
         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
         return img
 
-    def process_img(self, x, index=0, h_offset=0, w_offset=0):
+    def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
         bs, c, h, w = x.shape
         patch_size = self.patch_size
         x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
@@ -222,10 +222,22 @@ def process_img(self, x, index=0, h_offset=0, w_offset=0):
         h_offset = ((h_offset + (patch_size // 2)) // patch_size)
         w_offset = ((w_offset + (patch_size // 2)) // patch_size)
 
-        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
+        steps_h = h_len
+        steps_w = w_len
+
+        rope_options = transformer_options.get("rope_options", None)
+        if rope_options is not None:
+            h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
+            w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
+
+            index += rope_options.get("shift_t", 0.0)
+            h_offset += rope_options.get("shift_y", 0.0)
+            w_offset += rope_options.get("shift_x", 0.0)
+
+        img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype)
         img_ids[:, :, 0] = img_ids[:, :, 1] + index
-        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
-        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
+        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1)
+        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0)
         return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
 
     def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
@@ -241,7 +253,7 @@ def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None
 
         h_len = ((h_orig + (patch_size // 2)) // patch_size)
         w_len = ((w_orig + (patch_size // 2)) // patch_size)
-        img, img_ids = self.process_img(x)
+        img, img_ids = self.process_img(x, transformer_options=transformer_options)
         img_tokens = img.shape[1]
         if ref_latents is not None:
             h = 0
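
For context on what the new `rope_options` entries do: the token grid keeps its original `steps_h x steps_w` shape, but the coordinate range that `torch.linspace` spreads those tokens over is stretched by `scale_x`/`scale_y` and offset by `shift_x`/`shift_y`/`shift_t`. Below is a minimal standalone sketch of that math in plain PyTorch; `make_img_ids` is a hypothetical helper that only mirrors the lines added in this diff, not ComfyUI's own API.

```python
import torch

# Hypothetical standalone reimplementation of the position-id math added above.
def make_img_ids(h_len, w_len, index=0, h_offset=0, w_offset=0, rope_options=None):
    steps_h, steps_w = h_len, w_len  # token grid size stays fixed
    if rope_options is not None:
        # Stretch the coordinate extent, keep the number of tokens unchanged.
        h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
        w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
        # Translate the grid in the temporal/spatial RoPE axes.
        index += rope_options.get("shift_t", 0.0)
        h_offset += rope_options.get("shift_y", 0.0)
        w_offset += rope_options.get("shift_x", 0.0)
    img_ids = torch.zeros((steps_h, steps_w, 3))
    img_ids[:, :, 0] = index
    img_ids[:, :, 1] = torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h).unsqueeze(1)
    img_ids[:, :, 2] = torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w).unsqueeze(0)
    return img_ids

# With scale_y=2.0 a 3x3 grid still has 9 tokens, but its row coordinates
# become 0, 2, 4 instead of 0, 1, 2 -- the same tokens cover a wider RoPE extent.
print(make_img_ids(3, 3, rope_options={"scale_y": 2.0})[:, 0, 1])  # tensor([0., 2., 4.])
```

In the model itself these values would be supplied through `transformer_options["rope_options"]`, which `_forward` now forwards into `process_img`.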