diff --git a/InstantID.py b/InstantID.py
index d7de7ea..43cfee3 100644
--- a/InstantID.py
+++ b/InstantID.py
@@ -51,36 +51,41 @@ def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,2
     out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
     return out_img_pil
 
-class CrossAttentionPatchIID:
+class CrossAttentionPatch:
     # forward for patching
-    def __init__(self, weight, instantid, number, cond, uncond, mask=None, sigma_start=0.0, sigma_end=1.0):
+    def __init__(self, weight, ipadapter, number, cond, uncond, weight_type="original", mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False):
         self.weights = [weight]
-        self.instantid = [instantid]
+        self.ipadapters = [ipadapter]
         self.conds = [cond]
         self.unconds = [uncond]
         self.number = number
+        self.weight_type = [weight_type]
         self.masks = [mask]
         self.sigma_start = [sigma_start]
         self.sigma_end = [sigma_end]
+        self.unfold_batch = [unfold_batch]
 
         self.k_key = str(self.number*2+1) + "_to_k_ip"
         self.v_key = str(self.number*2+1) + "_to_v_ip"
 
-    def set_new_condition(self, weight, instantid, number, cond, uncond, mask=None, sigma_start=0.0, sigma_end=1.0):
+    def set_new_condition(self, weight, ipadapter, number, cond, uncond, weight_type="original", mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False):
         self.weights.append(weight)
-        self.instantid.append(instantid)
+        self.ipadapters.append(ipadapter)
         self.conds.append(cond)
         self.unconds.append(uncond)
         self.masks.append(mask)
+        self.weight_type.append(weight_type)
         self.sigma_start.append(sigma_start)
         self.sigma_end.append(sigma_end)
+        self.unfold_batch.append(unfold_batch)
 
     def __call__(self, n, context_attn2, value_attn2, extra_options):
         org_dtype = n.dtype
         cond_or_uncond = extra_options["cond_or_uncond"]
-
-        sigma = extra_options["sigmas"][0] if 'sigmas' in extra_options else None
-        sigma = sigma.item() if sigma else 999999999.9
+        sigma = extra_options["sigmas"][0].item() if 'sigmas' in extra_options else 999999999.9
+
+        # extra options for AnimateDiff
+        ad_params = extra_options['ad_params'] if "ad_params" in extra_options else None
 
         q = n
         k = context_attn2
@@ -91,41 +96,112 @@ def __call__(self, n, context_attn2, value_attn2, extra_options):
         out = optimized_attention(q, k, v, extra_options["n_heads"])
         _, _, lh, lw = extra_options["original_shape"]
 
-        for weight, cond, uncond, instantid, mask, sigma_start, sigma_end in zip(self.weights, self.conds, self.unconds, self.instantid, self.masks, self.sigma_start, self.sigma_end):
-            if sigma < sigma_start and sigma > sigma_end:
-                k_cond = instantid.ip_layers.to_kvs[self.k_key](cond).repeat(batch_prompt, 1, 1)
-                k_uncond = instantid.ip_layers.to_kvs[self.k_key](uncond).repeat(batch_prompt, 1, 1)
-                v_cond = instantid.ip_layers.to_kvs[self.v_key](cond).repeat(batch_prompt, 1, 1)
-                v_uncond = instantid.ip_layers.to_kvs[self.v_key](uncond).repeat(batch_prompt, 1, 1)
-
-                iid_k = torch.cat([(k_cond, k_uncond)[i] for i in cond_or_uncond], dim=0)
-                iid_v = torch.cat([(v_cond, v_uncond)[i] for i in cond_or_uncond], dim=0)
-
-                out_iid = optimized_attention(q, iid_k, iid_v, extra_options["n_heads"])
-                out_iid = out_iid * weight
-
-                if mask is not None:
-                    # TODO: needs checking
-                    mask_h = lh / math.sqrt(lh * lw / qs)
-                    mask_h = int(mask_h) + int((qs % int(mask_h)) != 0)
-                    mask_w = qs // mask_h
-
+        for weight, cond, uncond, ipadapter, mask, weight_type, sigma_start, sigma_end, unfold_batch in zip(self.weights, self.conds, self.unconds, self.ipadapters, self.masks, self.weight_type, self.sigma_start, self.sigma_end, self.unfold_batch):
+            if sigma > sigma_start or sigma < sigma_end:
+                continue
+
+            if unfold_batch and cond.shape[0] > 1:
+                # Check AnimateDiff context window
+                if ad_params is not None and ad_params["sub_idxs"] is not None:
+                    # if images length matches or exceeds full_length get sub_idx images
+                    if cond.shape[0] >= ad_params["full_length"]:
+                        cond = torch.Tensor(cond[ad_params["sub_idxs"]])
+                        uncond = torch.Tensor(uncond[ad_params["sub_idxs"]])
+                    # otherwise, need to do more to get proper sub_idxs images
+                    else:
+                        # check if images length matches full_length - if not, make it match
+                        if cond.shape[0] < ad_params["full_length"]:
+                            cond = torch.cat((cond, cond[-1:].repeat((ad_params["full_length"]-cond.shape[0], 1, 1))), dim=0)
+                            uncond = torch.cat((uncond, uncond[-1:].repeat((ad_params["full_length"]-uncond.shape[0], 1, 1))), dim=0)
+                        # if we have too many remove the excess (should not happen, but just in case)
+                        if cond.shape[0] > ad_params["full_length"]:
+                            cond = cond[:ad_params["full_length"]]
+                            uncond = uncond[:ad_params["full_length"]]
+                        cond = cond[ad_params["sub_idxs"]]
+                        uncond = uncond[ad_params["sub_idxs"]]
+
+                # if we don't have enough reference images repeat the last one until we reach the right size
+                if cond.shape[0] < batch_prompt:
+                    cond = torch.cat((cond, cond[-1:].repeat((batch_prompt-cond.shape[0], 1, 1))), dim=0)
+                    uncond = torch.cat((uncond, uncond[-1:].repeat((batch_prompt-uncond.shape[0], 1, 1))), dim=0)
+                # if we have too many remove the excess
+                elif cond.shape[0] > batch_prompt:
+                    cond = cond[:batch_prompt]
+                    uncond = uncond[:batch_prompt]
+
+                k_cond = ipadapter.ip_layers.to_kvs[self.k_key](cond)
+                k_uncond = ipadapter.ip_layers.to_kvs[self.k_key](uncond)
+                v_cond = ipadapter.ip_layers.to_kvs[self.v_key](cond)
+                v_uncond = ipadapter.ip_layers.to_kvs[self.v_key](uncond)
+            else:
+                k_cond = ipadapter.ip_layers.to_kvs[self.k_key](cond).repeat(batch_prompt, 1, 1)
+                k_uncond = ipadapter.ip_layers.to_kvs[self.k_key](uncond).repeat(batch_prompt, 1, 1)
+                v_cond = ipadapter.ip_layers.to_kvs[self.v_key](cond).repeat(batch_prompt, 1, 1)
+                v_uncond = ipadapter.ip_layers.to_kvs[self.v_key](uncond).repeat(batch_prompt, 1, 1)
+
+            if weight_type.startswith("linear"):
+                ip_k = torch.cat([(k_cond, k_uncond)[i] for i in cond_or_uncond], dim=0) * weight
+                ip_v = torch.cat([(v_cond, v_uncond)[i] for i in cond_or_uncond], dim=0) * weight
+            else:
+                ip_k = torch.cat([(k_cond, k_uncond)[i] for i in cond_or_uncond], dim=0)
+                ip_v = torch.cat([(v_cond, v_uncond)[i] for i in cond_or_uncond], dim=0)
+
+            if weight_type.startswith("channel"):
+                # code by Lvmin Zhang at Stanford University, as also seen in the Fooocus IPAdapter implementation
+                ip_v_mean = torch.mean(ip_v, dim=1, keepdim=True)
+                ip_v_offset = ip_v - ip_v_mean
+                _, _, C = ip_k.shape
+                channel_penalty = float(C) / 1280.0
+                W = weight * channel_penalty
+                ip_k = ip_k * W
+                ip_v = ip_v_offset + ip_v_mean * W
+
+            out_ip = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"])
+            if weight_type.startswith("original"):
+                out_ip = out_ip * weight
+
+            if mask is not None:
+                # TODO: needs checking
+                mask_h = lh / math.sqrt(lh * lw / qs)
+                mask_h = int(mask_h) + int((qs % int(mask_h)) != 0)
+                mask_w = qs // mask_h
+
+                # check if using AnimateDiff and sliding context window
+                if (mask.shape[0] > 1 and ad_params is not None and ad_params["sub_idxs"] is not None):
+                    # if mask length matches or exceeds full_length, just get sub_idx masks, resize, and continue
+                    if mask.shape[0] >= ad_params["full_length"]:
+                        mask_downsample = torch.Tensor(mask[ad_params["sub_idxs"]])
+                        mask_downsample = F.interpolate(mask_downsample.unsqueeze(1), size=(mask_h, mask_w), mode="bicubic").squeeze(1)
+                    # otherwise, need to do more to get proper sub_idxs masks
+                    else:
+                        # resize to needed attention size (to save on memory)
+                        mask_downsample = F.interpolate(mask.unsqueeze(1), size=(mask_h, mask_w), mode="bicubic").squeeze(1)
+                        # check if mask length matches full_length - if not, make it match
+                        if mask_downsample.shape[0] < ad_params["full_length"]:
+                            mask_downsample = torch.cat((mask_downsample, mask_downsample[-1:].repeat((ad_params["full_length"]-mask_downsample.shape[0], 1, 1))), dim=0)
+                        # if we have too many remove the excess (should not happen, but just in case)
+                        if mask_downsample.shape[0] > ad_params["full_length"]:
+                            mask_downsample = mask_downsample[:ad_params["full_length"]]
+                        # now, select sub_idxs masks
+                        mask_downsample = mask_downsample[ad_params["sub_idxs"]]
+                # otherwise, perform usual mask interpolation
+                else:
                     mask_downsample = F.interpolate(mask.unsqueeze(1), size=(mask_h, mask_w), mode="bicubic").squeeze(1)
 
-                    # if we don't have enough masks repeat the last one until we reach the right size
-                    if mask_downsample.shape[0] < batch_prompt:
-                        mask_downsample = torch.cat((mask_downsample, mask_downsample[-1:, :, :].repeat((batch_prompt-mask_downsample.shape[0], 1, 1))), dim=0)
-                    # if we have too many remove the exceeding
-                    elif mask_downsample.shape[0] > batch_prompt:
-                        mask_downsample = mask_downsample[:batch_prompt, :, :]
-
-                    # repeat the masks
-                    mask_downsample = mask_downsample.repeat(len(cond_or_uncond), 1, 1)
-                    mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1, 1).repeat(1, 1, out.shape[2])
+                # if we don't have enough masks repeat the last one until we reach the right size
+                if mask_downsample.shape[0] < batch_prompt:
+                    mask_downsample = torch.cat((mask_downsample, mask_downsample[-1:, :, :].repeat((batch_prompt-mask_downsample.shape[0], 1, 1))), dim=0)
+                # if we have too many remove the excess
+                elif mask_downsample.shape[0] > batch_prompt:
+                    mask_downsample = mask_downsample[:batch_prompt, :, :]
+
+                # repeat the masks
+                mask_downsample = mask_downsample.repeat(len(cond_or_uncond), 1, 1)
+                mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1, 1).repeat(1, 1, out.shape[2])
 
-                    out_iid = out_iid * mask_downsample
+                out_ip = out_ip * mask_downsample
 
-                out = out + out_iid
+            out = out + out_ip
 
         return out.to(dtype=org_dtype)
@@ -195,12 +271,12 @@ def _set_model_patch_replace(model, patch_kwargs, key):
     to = model.model_options["transformer_options"]
     if "patches_replace" not in to:
         to["patches_replace"] = {}
-    if "attn2_iid" not in to["patches_replace"]:
-        to["patches_replace"]["attn2_iid"] = {}
-    if key not in to["patches_replace"]["attn2_iid"] or not isinstance(to["patches_replace"]["attn2_iid"][key], CrossAttentionPatchIID):
-        to["patches_replace"]["attn2_iid"][key] = CrossAttentionPatchIID(**patch_kwargs)
+    if "attn2" not in to["patches_replace"]:
+        to["patches_replace"]["attn2"] = {}
+    if key not in to["patches_replace"]["attn2"]:
+        to["patches_replace"]["attn2"][key] = CrossAttentionPatch(**patch_kwargs)
     else:
-        to["patches_replace"]["attn2_iid"][key].set_new_condition(**patch_kwargs)
+        to["patches_replace"]["attn2"][key].set_new_condition(**patch_kwargs)
 
 class InstantIDModelLoader:
     @classmethod
@@ -381,12 +457,13 @@ def apply_instantid(self, instantid, insightface, image_features, model, positiv
         patch_kwargs = {
             "number": 0,
             "weight": self.weight,
-            "instantid": self.instantid,
+            "ipadapter": self.instantid,
"ipadapter": self.instantid, "cond": image_prompt_embeds, "uncond": uncond_image_prompt_embeds, "mask": attn_mask, "sigma_start": sigma_start, "sigma_end": sigma_end, + "weight_type": "original", } if not is_sdxl: @@ -415,14 +492,14 @@ def apply_instantid(self, instantid, insightface, image_features, model, positiv pos = [] for t in positive: n = [t[0], t[1].copy()] - n[1]['cross_attn_controlnet'] = image_prompt_embeds.cpu() + n[1]['cross_attn_controlnet'] = image_prompt_embeds.to(comfy.model_management.intermediate_device()) pos.append(n) #pos[0][1]['cross_attn_controlnet'] = image_prompt_embeds.cpu() neg = [] for t in negative: n = [t[0], t[1].copy()] - n[1]['cross_attn_controlnet'] = uncond_image_prompt_embeds.cpu() + n[1]['cross_attn_controlnet'] = uncond_image_prompt_embeds.to(comfy.model_management.intermediate_device()) neg.append(n) #neg[0][1]['cross_attn_controlnet'] = uncond_image_prompt_embeds.cpu() diff --git a/instantID.json b/instantID.json deleted file mode 100644 index 48a4eb6..0000000 --- a/instantID.json +++ /dev/null @@ -1,674 +0,0 @@ -{ - "last_node_id": 14, - "last_link_id": 25, - "nodes": [ - { - "id": 8, - "type": "EmptyLatentImage", - "pos": [ - 716, - 1012 - ], - "size": { - "0": 315, - "1": 106 - }, - "flags": {}, - "order": 0, - "mode": 0, - "outputs": [ - { - "name": "LATENT", - "type": "LATENT", - "links": [ - 11 - ], - "shape": 3 - } - ], - "properties": { - "Node name for S&R": "EmptyLatentImage" - }, - "widgets_values": [ - 1024, - 1024, - 1 - ] - }, - { - "id": 9, - "type": "VAEDecode", - "pos": [ - 1544.937899902344, - 359.6265532226564 - ], - "size": { - "0": 210, - "1": 46 - }, - "flags": {}, - "order": 11, - "mode": 0, - "inputs": [ - { - "name": "samples", - "type": "LATENT", - "link": 12 - }, - { - "name": "vae", - "type": "VAE", - "link": 13 - } - ], - "outputs": [ - { - "name": "IMAGE", - "type": "IMAGE", - "links": [ - 14 - ], - "shape": 3, - "slot_index": 0 - } - ], - "properties": { - "Node name for S&R": "VAEDecode" - } - }, - { - "id": 2, - "type": "InstantIDModelLoader", - "pos": [ - 333, - 122 - ], - "size": { - "0": 315, - "1": 58 - }, - "flags": {}, - "order": 1, - "mode": 0, - "outputs": [ - { - "name": "INSTANTID", - "type": "INSTANTID", - "links": [ - 18 - ], - "shape": 3, - "slot_index": 0 - } - ], - "properties": { - "Node name for S&R": "InstantIDModelLoader" - }, - "widgets_values": [ - "ip-adapter.bin" - ] - }, - { - "id": 12, - "type": "InsightFaceLoaderIID", - "pos": [ - 321, - 259 - ], - "size": { - "0": 315, - "1": 58 - }, - "flags": {}, - "order": 2, - "mode": 0, - "outputs": [ - { - "name": "INSIGHTFACE", - "type": "INSIGHTFACE", - "links": [ - 17 - ], - "shape": 3, - "slot_index": 0 - } - ], - "properties": { - "Node name for S&R": "InsightFaceLoaderIID" - }, - "widgets_values": [ - "CPU" - ] - }, - { - "id": 3, - "type": "CheckpointLoaderSimple", - "pos": [ - 175, - 404 - ], - "size": { - "0": 315, - "1": 98 - }, - "flags": {}, - "order": 3, - "mode": 0, - "outputs": [ - { - "name": "MODEL", - "type": "MODEL", - "links": [ - 19 - ], - "shape": 3, - "slot_index": 0 - }, - { - "name": "CLIP", - "type": "CLIP", - "links": [ - 15, - 16 - ], - "shape": 3, - "slot_index": 1 - }, - { - "name": "VAE", - "type": "VAE", - "links": [ - 13 - ], - "shape": 3, - "slot_index": 2 - } - ], - "properties": { - "Node name for S&R": "CheckpointLoaderSimple" - }, - "widgets_values": [ - "sdxl/sd_xl_base_1.0_0.9vae.safetensors" - ] - }, - { - "id": 10, - "type": "PreviewImage", - "pos": [ - 1821.785899902344, - 370.2745532226563 - ], - "size": [ 
-        710.068173339844,
-        756.3098067626954
-      ],
-      "flags": {},
-      "order": 12,
-      "mode": 0,
-      "inputs": [
-        {
-          "name": "images",
-          "type": "IMAGE",
-          "link": 14
-        }
-      ],
-      "properties": {
-        "Node name for S&R": "PreviewImage"
-      }
-    },
-    {
-      "id": 6,
-      "type": "CLIPTextEncode",
-      "pos": [
-        628,
-        481
-      ],
-      "size": {
-        "0": 400,
-        "1": 200
-      },
-      "flags": {},
-      "order": 6,
-      "mode": 0,
-      "inputs": [
-        {
-          "name": "clip",
-          "type": "CLIP",
-          "link": 15
-        }
-      ],
-      "outputs": [
-        {
-          "name": "CONDITIONING",
-          "type": "CONDITIONING",
-          "links": [
-            22
-          ],
-          "shape": 3,
-          "slot_index": 0
-        }
-      ],
-      "properties": {
-        "Node name for S&R": "CLIPTextEncode"
-      },
-      "widgets_values": [
-        "anime portrait of a beautiful girl"
-      ]
-    },
-    {
-      "id": 4,
-      "type": "LoadImage",
-      "pos": [
-        129,
-        581
-      ],
-      "size": {
-        "0": 315,
-        "1": 314
-      },
-      "flags": {},
-      "order": 4,
-      "mode": 0,
-      "outputs": [
-        {
-          "name": "IMAGE",
-          "type": "IMAGE",
-          "links": [
-            21
-          ],
-          "shape": 3,
-          "slot_index": 0
-        },
-        {
-          "name": "MASK",
-          "type": "MASK",
-          "links": null,
-          "shape": 3
-        }
-      ],
-      "properties": {
-        "Node name for S&R": "LoadImage"
-      },
-      "widgets_values": [
-        "face4.jpg",
-        "image"
-      ]
-    },
-    {
-      "id": 7,
-      "type": "CLIPTextEncode",
-      "pos": [
-        618,
-        743
-      ],
-      "size": {
-        "0": 400,
-        "1": 200
-      },
-      "flags": {},
-      "order": 7,
-      "mode": 0,
-      "inputs": [
-        {
-          "name": "clip",
-          "type": "CLIP",
-          "link": 16
-        }
-      ],
-      "outputs": [
-        {
-          "name": "CONDITIONING",
-          "type": "CONDITIONING",
-          "links": [
-            10
-          ],
-          "shape": 3,
-          "slot_index": 0
-        }
-      ],
-      "properties": {
-        "Node name for S&R": "CLIPTextEncode"
-      },
-      "widgets_values": [
-        "blurry, lowres, malformed, horror"
-      ]
-    },
-    {
-      "id": 5,
-      "type": "KSampler",
-      "pos": [
-        1129.6658999023439,
-        366.2815532226564
-      ],
-      "size": {
-        "0": 315,
-        "1": 262
-      },
-      "flags": {},
-      "order": 10,
-      "mode": 0,
-      "inputs": [
-        {
-          "name": "model",
-          "type": "MODEL",
-          "link": 20
-        },
-        {
-          "name": "positive",
-          "type": "CONDITIONING",
-          "link": 23,
-          "slot_index": 1
-        },
-        {
-          "name": "negative",
-          "type": "CONDITIONING",
-          "link": 10
-        },
-        {
-          "name": "latent_image",
-          "type": "LATENT",
-          "link": 11,
-          "slot_index": 3
-        }
-      ],
-      "outputs": [
-        {
-          "name": "LATENT",
-          "type": "LATENT",
-          "links": [
-            12
-          ],
-          "shape": 3,
-          "slot_index": 0
-        }
-      ],
-      "properties": {
-        "Node name for S&R": "KSampler"
-      },
-      "widgets_values": [
-        611076132277670,
-        "fixed",
-        20,
-        8,
-        "euler",
-        "normal",
-        1
-      ]
-    },
-    {
-      "id": 11,
-      "type": "ApplyInstantID",
-      "pos": [
-        758.7900097656251,
-        238.68000427246096
-      ],
-      "size": {
-        "0": 216.59999084472656,
-        "1": 86
-      },
-      "flags": {},
-      "order": 8,
-      "mode": 0,
-      "inputs": [
-        {
-          "name": "instantid",
-          "type": "INSTANTID",
-          "link": 18
-        },
-        {
-          "name": "insightface",
-          "type": "INSIGHTFACE",
-          "link": 17
-        },
-        {
-          "name": "model",
-          "type": "MODEL",
-          "link": 19
-        },
-        {
-          "name": "image",
-          "type": "IMAGE",
-          "link": 21
-        }
-      ],
-      "outputs": [
-        {
-          "name": "MODEL",
-          "type": "MODEL",
-          "links": [
-            20
-          ],
-          "shape": 3,
-          "slot_index": 0
-        },
-        {
-          "name": "IMAGE_KPS",
-          "type": "IMAGE",
-          "links": [
-            24
-          ],
-          "shape": 3,
-          "slot_index": 1
-        }
-      ],
-      "properties": {
-        "Node name for S&R": "ApplyInstantID"
-      }
-    },
-    {
-      "id": 14,
-      "type": "ControlNetLoader",
-      "pos": [
-        676,
-        9
-      ],
-      "size": {
-        "0": 315,
-        "1": 58
-      },
-      "flags": {},
-      "order": 5,
-      "mode": 0,
-      "outputs": [
-        {
-          "name": "CONTROL_NET",
-          "type": "CONTROL_NET",
-          "links": [
-            25
-          ],
-          "shape": 3
-        }
-      ],
-      "properties": {
-        "Node name for S&R": "ControlNetLoader"
"ControlNetLoader" - }, - "widgets_values": [ - "instantid/diffusion_pytorch_model.safetensors" - ] - }, - { - "id": 13, - "type": "ControlNetApply", - "pos": [ - 1085, - 122 - ], - "size": { - "0": 317.4000244140625, - "1": 98 - }, - "flags": {}, - "order": 9, - "mode": 0, - "inputs": [ - { - "name": "conditioning", - "type": "CONDITIONING", - "link": 22 - }, - { - "name": "control_net", - "type": "CONTROL_NET", - "link": 25, - "slot_index": 1 - }, - { - "name": "image", - "type": "IMAGE", - "link": 24 - } - ], - "outputs": [ - { - "name": "CONDITIONING", - "type": "CONDITIONING", - "links": [ - 23 - ], - "shape": 3, - "slot_index": 0 - } - ], - "properties": { - "Node name for S&R": "ControlNetApply" - }, - "widgets_values": [ - 0.3 - ] - } - ], - "links": [ - [ - 10, - 7, - 0, - 5, - 2, - "CONDITIONING" - ], - [ - 11, - 8, - 0, - 5, - 3, - "LATENT" - ], - [ - 12, - 5, - 0, - 9, - 0, - "LATENT" - ], - [ - 13, - 3, - 2, - 9, - 1, - "VAE" - ], - [ - 14, - 9, - 0, - 10, - 0, - "IMAGE" - ], - [ - 15, - 3, - 1, - 6, - 0, - "CLIP" - ], - [ - 16, - 3, - 1, - 7, - 0, - "CLIP" - ], - [ - 17, - 12, - 0, - 11, - 1, - "INSIGHTFACE" - ], - [ - 18, - 2, - 0, - 11, - 0, - "INSTANTID" - ], - [ - 19, - 3, - 0, - 11, - 2, - "MODEL" - ], - [ - 20, - 11, - 0, - 5, - 0, - "MODEL" - ], - [ - 21, - 4, - 0, - 11, - 3, - "IMAGE" - ], - [ - 22, - 6, - 0, - 13, - 0, - "CONDITIONING" - ], - [ - 23, - 13, - 0, - 5, - 1, - "CONDITIONING" - ], - [ - 24, - 11, - 1, - 13, - 2, - "IMAGE" - ], - [ - 25, - 14, - 0, - 13, - 1, - "CONTROL_NET" - ] - ], - "groups": [], - "config": {}, - "extra": {}, - "version": 0.4 -} \ No newline at end of file