From 047f96960981c48cefe833b128bc5127aa597bf7 Mon Sep 17 00:00:00 2001
From: matt3o
Date: Sun, 11 Feb 2024 10:32:50 +0100
Subject: [PATCH] add compatibility with ipadapter

---
 InstantID.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/InstantID.py b/InstantID.py
index fa25c37..818e3d1 100644
--- a/InstantID.py
+++ b/InstantID.py
@@ -51,7 +51,7 @@ def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,2
     out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
     return out_img_pil
 
-class CrossAttentionPatch:
+class CrossAttentionPatchIID:
     # forward for patching
     def __init__(self, weight, instantid, number, cond, uncond, mask=None, sigma_start=0.0, sigma_end=1.0):
         self.weights = [weight]
@@ -191,14 +191,14 @@ def __init__(self, state_dict):
             self.to_kvs[k] = torch.nn.Linear(value.shape[1], value.shape[0], bias=False)
             self.to_kvs[k].weight.data = value
 
-def set_model_patch_replace(model, patch_kwargs, key):
+def _set_model_patch_replace(model, patch_kwargs, key):
     to = model.model_options["transformer_options"]
     if "patches_replace" not in to:
         to["patches_replace"] = {}
     if "attn2" not in to["patches_replace"]:
         to["patches_replace"]["attn2"] = {}
-    if key not in to["patches_replace"]["attn2"]:
-        patch = CrossAttentionPatch(**patch_kwargs)
+    if key not in to["patches_replace"]["attn2"] or not isinstance(to["patches_replace"]["attn2"][key], CrossAttentionPatchIID):
+        patch = CrossAttentionPatchIID(**patch_kwargs)
         to["patches_replace"]["attn2"][key] = patch
     else:
         to["patches_replace"]["attn2"][key].set_new_condition(**patch_kwargs)
@@ -388,25 +388,25 @@ def apply_instantid(self, instantid, insightface, image_features, model, positiv
 
         if not is_sdxl:
             for id in [1,2,4,5,7,8]: # id of input_blocks that have cross attention
-                set_model_patch_replace(work_model, patch_kwargs, ("input", id))
+                _set_model_patch_replace(work_model, patch_kwargs, ("input", id))
                 patch_kwargs["number"] += 1
             for id in [3,4,5,6,7,8,9,10,11]: # id of output_blocks that have cross attention
-                set_model_patch_replace(work_model, patch_kwargs, ("output", id))
+                _set_model_patch_replace(work_model, patch_kwargs, ("output", id))
                 patch_kwargs["number"] += 1
-            set_model_patch_replace(work_model, patch_kwargs, ("middle", 0))
+            _set_model_patch_replace(work_model, patch_kwargs, ("middle", 0))
         else:
             for id in [4,5,7,8]: # id of input_blocks that have cross attention
                 block_indices = range(2) if id in [4, 5] else range(10) # transformer_depth
                 for index in block_indices:
-                    set_model_patch_replace(work_model, patch_kwargs, ("input", id, index))
+                    _set_model_patch_replace(work_model, patch_kwargs, ("input", id, index))
                     patch_kwargs["number"] += 1
             for id in range(6): # id of output_blocks that have cross attention
                 block_indices = range(2) if id in [3, 4, 5] else range(10) # transformer_depth
                 for index in block_indices:
-                    set_model_patch_replace(work_model, patch_kwargs, ("output", id, index))
+                    _set_model_patch_replace(work_model, patch_kwargs, ("output", id, index))
                     patch_kwargs["number"] += 1
             for index in range(10):
-                set_model_patch_replace(work_model, patch_kwargs, ("middle", 0, index))
+                _set_model_patch_replace(work_model, patch_kwargs, ("middle", 0, index))
                 patch_kwargs["number"] += 1
 
         pos = []
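
The substantive change is the guard in _set_model_patch_replace: set_new_condition() is now only called when the object already stored at the attn2 key is InstantID's own CrossAttentionPatchIID. If another extension (e.g. an IPAdapter node pack) has already claimed the key, InstantID installs a fresh patch object instead of calling a method on a foreign one. A minimal standalone sketch of that logic, operating on the transformer_options dict directly rather than a ModelPatcher; the stub classes are hypothetical and for illustration only:

    # Sketch of the guarded install in _set_model_patch_replace (stubs are
    # hypothetical, not the real classes from either extension).

    class CrossAttentionPatchIID:
        """Stand-in for InstantID's renamed patch class."""
        def __init__(self, **patch_kwargs):
            self.conditions = [patch_kwargs]

        def set_new_condition(self, **patch_kwargs):
            # Stacking conditions is only safe on our own patch type.
            self.conditions.append(patch_kwargs)

    class ForeignPatch:
        """Stand-in for a patch installed by another extension."""
        pass

    def _set_model_patch_replace(to, patch_kwargs, key):
        attn2 = to.setdefault("patches_replace", {}).setdefault("attn2", {})
        if key not in attn2 or not isinstance(attn2[key], CrossAttentionPatchIID):
            # Slot is empty or held by a foreign patch: install our own patch
            # instead of calling set_new_condition() on an unknown object.
            attn2[key] = CrossAttentionPatchIID(**patch_kwargs)
        else:
            # Slot already holds an InstantID patch: stack the new condition.
            attn2[key].set_new_condition(**patch_kwargs)

    to = {"patches_replace": {"attn2": {("input", 4): ForeignPatch()}}}
    _set_model_patch_replace(to, {"weight": 0.8}, ("input", 4))
    assert isinstance(to["patches_replace"]["attn2"][("input", 4)], CrossAttentionPatchIID)

Note that the guard resolves a collision by overwriting the foreign patch at that key, not by merging the two: its point is to avoid calling set_new_condition() on an object with a different, or missing, implementation of that method. Renaming the class to CrossAttentionPatchIID likewise keeps it distinguishable from similarly named patch classes in related extensions.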
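
For reference, the loops in apply_instantid enumerate every attn2 (cross attention) layer that ComfyUI exposes for patch replacement: SD1.5 keys are (block_type, block_id) pairs, while SDXL keys carry a third index because its blocks stack several transformer layers (transformer_depth). A compact sketch of the generated keys; iter_attn2_keys is a hypothetical helper, not part of the patch:

    # Enumerate the same patch keys as the loops in apply_instantid above.
    def iter_attn2_keys(is_sdxl):
        if not is_sdxl:
            for id in [1, 2, 4, 5, 7, 8]:             # input_blocks with cross attention
                yield ("input", id)
            for id in [3, 4, 5, 6, 7, 8, 9, 10, 11]:  # output_blocks with cross attention
                yield ("output", id)
            yield ("middle", 0)
        else:
            for id in [4, 5, 7, 8]:                   # SDXL input_blocks
                for index in (range(2) if id in [4, 5] else range(10)):
                    yield ("input", id, index)
            for id in range(6):                       # SDXL output_blocks
                for index in (range(2) if id in [3, 4, 5] else range(10)):
                    yield ("output", id, index)
            for index in range(10):                   # SDXL middle block
                yield ("middle", 0, index)

    print(sum(1 for _ in iter_attn2_keys(False)))  # 16 patched layers (SD1.5)
    print(sum(1 for _ in iter_attn2_keys(True)))   # 70 patched layers (SDXL)

Since _set_model_patch_replace is called once per key, patch_kwargs["number"] ends up incremented 16 times for SD1.5 and 70 times for SDXL, one per patched attention layer.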