
Commit dff70eb

Support for qwen edit plus model. Use the new TextEncodeQwenImageEditPlus. (comfyanonymous#9986)
1 parent 229e7f9 commit dff70eb
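To make the commit message concrete, here is a minimal sketch of how the new node might be wired in an API-format workflow. The node ids, loader nodes, and file names are illustrative assumptions, not part of this commit; only the TextEncodeQwenImageEditPlus input names (clip, prompt, vae, image1-image3) come from the diff below.

    # Hypothetical API-format workflow fragment; ids and file names are placeholders.
    workflow = {
        "1": {"class_type": "CLIPLoader",
              "inputs": {"clip_name": "qwen_2.5_vl_7b.safetensors", "type": "qwen_image"}},
        "2": {"class_type": "VAELoader",
              "inputs": {"vae_name": "qwen_image_vae.safetensors"}},
        "3": {"class_type": "LoadImage", "inputs": {"image": "example.png"}},
        "4": {"class_type": "TextEncodeQwenImageEditPlus",
              "inputs": {"clip": ["1", 0],
                         "prompt": "replace the background with a rainy street",
                         "vae": ["2", 0],
                         "image1": ["3", 0]}},  # image2 / image3 are optional
    }

The conditioning output of node "4" then feeds a sampler as usual.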

File tree

2 files changed: +65, -6 lines

comfy/text_encoders/llama.py

Lines changed: 10 additions & 6 deletions
@@ -400,21 +400,25 @@ def preprocess_embed(self, embed, device):
 
     def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
         grid = None
+        position_ids = None
+        offset = 0
         for e in embeds_info:
             if e.get("type") == "image":
                 grid = e.get("extra", None)
-                position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
                 start = e.get("index")
-                position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
+                if position_ids is None:
+                    position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
+                    position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
                 end = e.get("size") + start
                 len_max = int(grid.max()) // 2
                 start_next = len_max + start
-                position_ids[:, end:] = torch.arange(start_next, start_next + (embeds.shape[1] - end), device=embeds.device)
-                position_ids[0, start:end] = start
+                position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
+                position_ids[0, start:end] = start + offset
                 max_d = int(grid[0][1]) // 2
-                position_ids[1, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
+                position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
                 max_d = int(grid[0][2]) // 2
-                position_ids[2, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
+                position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
+                offset += len_max - (end - start)
 
         if grid is None:
             position_ids = None
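Context for the llama.py change (not part of the diff): position ids are now built once per sequence, and an offset accumulates the difference between the rotary span reserved for each image (len_max, half the largest vision-grid dimension) and the number of embedding slots the image actually occupies, so text after the second and third images keeps monotonically increasing positions. The standalone sketch below reproduces the same arithmetic on invented numbers; sequence length, slot counts, and grids are made up for illustration and only the logic mirrors the diff.

    import math
    import torch

    def toy_position_ids(seq_len, images):
        # `images`: list of dicts with "index" (first placeholder slot), "size"
        # (number of placeholder slots) and "grid" (a (t, h, w) stand-in for the
        # vision patch grid). All values here are hypothetical.
        position_ids = None
        offset = 0
        for e in images:
            grid = torch.tensor([e["grid"]])
            start = e["index"]
            if position_ids is None:
                position_ids = torch.zeros((3, seq_len))
                position_ids[:, :start] = torch.arange(0, start)
            end = e["size"] + start
            len_max = int(grid.max()) // 2          # rotary span reserved for this image
            start_next = len_max + start
            position_ids[:, end:] = torch.arange(start_next + offset, start_next + (seq_len - end) + offset)
            position_ids[0, start:end] = start + offset
            max_d = int(grid[0][1]) // 2
            position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
            max_d = int(grid[0][2]) // 2
            position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
            offset += len_max - (end - start)       # carry the surplus into later positions
        return position_ids

    # Two hypothetical 12x12-grid images, each taking 4 slots of a 16-token sequence:
    # len_max is 6 per image, so offset grows by 2 after each one and the trailing
    # text positions shift accordingly.
    print(toy_position_ids(16, [
        {"index": 2, "size": 4, "grid": (1, 12, 12)},
        {"index": 8, "size": 4, "grid": (1, 12, 12)},
    ]))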

comfy_extras/nodes_qwen.py

Lines changed: 55 additions & 0 deletions
@@ -43,6 +43,61 @@ def encode(self, clip, prompt, vae=None, image=None):
         return (conditioning, )
 
 
+class TextEncodeQwenImageEditPlus:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "clip": ("CLIP", ),
+            "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
+            },
+            "optional": {"vae": ("VAE", ),
+                         "image1": ("IMAGE", ),
+                         "image2": ("IMAGE", ),
+                         "image3": ("IMAGE", ),
+                         }}
+
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "encode"
+
+    CATEGORY = "advanced/conditioning"
+
+    def encode(self, clip, prompt, vae=None, image1=None, image2=None, image3=None):
+        ref_latents = []
+        images = [image1, image2, image3]
+        images_vl = []
+        llama_template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+        image_prompt = ""
+
+        for i, image in enumerate(images):
+            if image is not None:
+                samples = image.movedim(-1, 1)
+                total = int(384 * 384)
+
+                scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
+                width = round(samples.shape[3] * scale_by)
+                height = round(samples.shape[2] * scale_by)
+
+                s = comfy.utils.common_upscale(samples, width, height, "area", "disabled")
+                images_vl.append(s.movedim(1, -1))
+                if vae is not None:
+                    total = int(1024 * 1024)
+                    scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
+                    width = round(samples.shape[3] * scale_by / 8.0) * 8
+                    height = round(samples.shape[2] * scale_by / 8.0) * 8
+
+                    s = comfy.utils.common_upscale(samples, width, height, "area", "disabled")
+                    ref_latents.append(vae.encode(s.movedim(1, -1)[:, :, :, :3]))
+
+                image_prompt += "Picture {}: <|vision_start|><|image_pad|><|vision_end|>".format(i + 1)
+
+        tokens = clip.tokenize(image_prompt + prompt, images=images_vl, llama_template=llama_template)
+        conditioning = clip.encode_from_tokens_scheduled(tokens)
+        if len(ref_latents) > 0:
+            conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True)
+        return (conditioning, )
+
+
 NODE_CLASS_MAPPINGS = {
     "TextEncodeQwenImageEdit": TextEncodeQwenImageEdit,
+    "TextEncodeQwenImageEditPlus": TextEncodeQwenImageEditPlus,
 }
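For orientation on the resolution handling in encode() above (not part of the diff): each connected image is resized twice, once to roughly a 384x384 pixel budget for the Qwen2.5-VL tokenizer input, and once, when a VAE is supplied, to roughly a 1024x1024 pixel budget with width and height snapped to multiples of 8 before VAE encoding into a reference latent. The sketch below reproduces just that arithmetic; the helper name and example size are invented for illustration.

    import math

    def qwen_edit_plus_targets(width, height):
        # Hypothetical helper mirroring the two target resolutions computed in encode().
        # ~384*384 pixel budget for the vision-language (tokenizer) copy.
        scale_vl = math.sqrt((384 * 384) / (width * height))
        vl_size = (round(width * scale_vl), round(height * scale_vl))

        # ~1024*1024 pixel budget, snapped to multiples of 8, for the VAE copy
        # that becomes a reference latent.
        scale_ref = math.sqrt((1024 * 1024) / (width * height))
        ref_size = (round(width * scale_ref / 8.0) * 8, round(height * scale_ref / 8.0) * 8)
        return vl_size, ref_size

    # A 1920x1080 input ends up at 512x288 for the VL encoder and
    # 1368x768 for the VAE reference latent.
    print(qwen_edit_plus_targets(1920, 1080))

Snapping the reference copy to multiples of 8 matches the usual latent downscale factor; the smaller VL copy only needs an approximate pixel area, so it is rounded to whole pixels.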
