Skip to content

Commit 2abd2b5

Browse files
Make ScaleROPE node work on Flux. (#10686)
1 parent a1a7036 commit 2abd2b5

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

comfy/ldm/flux/model.py

Lines changed: 17 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -210,7 +210,7 @@ def block_wrap(args):
210210
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
211211
return img
212212

213-
def process_img(self, x, index=0, h_offset=0, w_offset=0):
213+
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
214214
bs, c, h, w = x.shape
215215
patch_size = self.patch_size
216216
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
@@ -222,10 +222,22 @@ def process_img(self, x, index=0, h_offset=0, w_offset=0):
222222
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
223223
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
224224

225-
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
225+
steps_h = h_len
226+
steps_w = w_len
227+
228+
rope_options = transformer_options.get("rope_options", None)
229+
if rope_options is not None:
230+
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
231+
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
232+
233+
index += rope_options.get("shift_t", 0.0)
234+
h_offset += rope_options.get("shift_y", 0.0)
235+
w_offset += rope_options.get("shift_x", 0.0)
236+
237+
img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype)
226238
img_ids[:, :, 0] = img_ids[:, :, 1] + index
227-
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
228-
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
239+
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1)
240+
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0)
229241
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
230242

231243
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
@@ -241,7 +253,7 @@ def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None
241253

242254
h_len = ((h_orig + (patch_size // 2)) // patch_size)
243255
w_len = ((w_orig + (patch_size // 2)) // patch_size)
244-
img, img_ids = self.process_img(x)
256+
img, img_ids = self.process_img(x, transformer_options=transformer_options)
245257
img_tokens = img.shape[1]
246258
if ref_latents is not None:
247259
h = 0

0 commit comments

Comments (0)