
Commit

add compatibility with ipadapter
matt3o committed Feb 11, 2024
1 parent 504b6d3 commit 047f969
Showing 1 changed file with 10 additions and 10 deletions.
InstantID.py: 20 changes (10 additions & 10 deletions)
@@ -51,7 +51,7 @@ def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,2
     out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
     return out_img_pil
 
-class CrossAttentionPatch:
+class CrossAttentionPatchIID:
     # forward for patching
     def __init__(self, weight, instantid, number, cond, uncond, mask=None, sigma_start=0.0, sigma_end=1.0):
         self.weights = [weight]
@@ -191,14 +191,14 @@ def __init__(self, state_dict):
             self.to_kvs[k] = torch.nn.Linear(value.shape[1], value.shape[0], bias=False)
             self.to_kvs[k].weight.data = value
 
-def set_model_patch_replace(model, patch_kwargs, key):
+def _set_model_patch_replace(model, patch_kwargs, key):
     to = model.model_options["transformer_options"]
     if "patches_replace" not in to:
         to["patches_replace"] = {}
     if "attn2" not in to["patches_replace"]:
         to["patches_replace"]["attn2"] = {}
-    if key not in to["patches_replace"]["attn2"]:
-        patch = CrossAttentionPatch(**patch_kwargs)
+    if key not in to["patches_replace"]["attn2"] or not isinstance(to["patches_replace"]["attn2"][key], CrossAttentionPatchIID):
+        patch = CrossAttentionPatchIID(**patch_kwargs)
         to["patches_replace"]["attn2"][key] = patch
     else:
         to["patches_replace"]["attn2"][key].set_new_condition(**patch_kwargs)
@@ -388,25 +388,25 @@ def apply_instantid(self, instantid, insightface, image_features, model, positiv
 
         if not is_sdxl:
             for id in [1,2,4,5,7,8]: # id of input_blocks that have cross attention
-                set_model_patch_replace(work_model, patch_kwargs, ("input", id))
+                _set_model_patch_replace(work_model, patch_kwargs, ("input", id))
                 patch_kwargs["number"] += 1
             for id in [3,4,5,6,7,8,9,10,11]: # id of output_blocks that have cross attention
-                set_model_patch_replace(work_model, patch_kwargs, ("output", id))
+                _set_model_patch_replace(work_model, patch_kwargs, ("output", id))
                 patch_kwargs["number"] += 1
-            set_model_patch_replace(work_model, patch_kwargs, ("middle", 0))
+            _set_model_patch_replace(work_model, patch_kwargs, ("middle", 0))
         else:
             for id in [4,5,7,8]: # id of input_blocks that have cross attention
                 block_indices = range(2) if id in [4, 5] else range(10) # transformer_depth
                 for index in block_indices:
-                    set_model_patch_replace(work_model, patch_kwargs, ("input", id, index))
+                    _set_model_patch_replace(work_model, patch_kwargs, ("input", id, index))
                     patch_kwargs["number"] += 1
             for id in range(6): # id of output_blocks that have cross attention
                 block_indices = range(2) if id in [3, 4, 5] else range(10) # transformer_depth
                 for index in block_indices:
-                    set_model_patch_replace(work_model, patch_kwargs, ("output", id, index))
+                    _set_model_patch_replace(work_model, patch_kwargs, ("output", id, index))
                     patch_kwargs["number"] += 1
             for index in range(10):
-                set_model_patch_replace(work_model, patch_kwargs, ("middle", 0, index))
+                _set_model_patch_replace(work_model, patch_kwargs, ("middle", 0, index))
                 patch_kwargs["number"] += 1
 
         pos = []
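For context, below is a minimal standalone sketch (not the repository's code) of the guarded patch replacement this commit introduces. A plain dict stands in for ComfyUI's model.model_options, and SomeOtherPatch is a hypothetical placeholder for a patch installed by a different extension such as IPAdapter; the isinstance() check ensures set_new_condition is only ever called on a patch object InstantID itself created.

# Minimal sketch (not the repository's code) of the guarded attn2 patch
# replacement. A plain dict replaces ComfyUI's model.model_options, and
# SomeOtherPatch is a hypothetical stand-in for another extension's patch.

class CrossAttentionPatchIID:
    """Toy stand-in for InstantID's renamed attn2 patch class."""
    def __init__(self, **kwargs):
        self.conditions = [kwargs]

    def set_new_condition(self, **kwargs):
        # Stack an extra condition onto an existing InstantID patch.
        self.conditions.append(kwargs)


class SomeOtherPatch:
    """Hypothetical patch object owned by a different extension (e.g. IPAdapter)."""


def _set_model_patch_replace(model_options, patch_kwargs, key):
    to = model_options.setdefault("transformer_options", {})
    attn2 = to.setdefault("patches_replace", {}).setdefault("attn2", {})
    # The isinstance() guard is the point of the commit: only reuse an entry
    # that InstantID installed; if the slot is empty or holds another
    # extension's object, install a fresh InstantID patch instead.
    if key not in attn2 or not isinstance(attn2[key], CrossAttentionPatchIID):
        attn2[key] = CrossAttentionPatchIID(**patch_kwargs)
    else:
        attn2[key].set_new_condition(**patch_kwargs)


if __name__ == "__main__":
    opts = {"transformer_options": {"patches_replace": {"attn2": {("input", 4): SomeOtherPatch()}}}}
    _set_model_patch_replace(opts, {"weight": 1.0}, ("input", 4))  # replaces the foreign patch
    _set_model_patch_replace(opts, {"weight": 0.5}, ("input", 4))  # stacks onto the InstantID patch
    patch = opts["transformer_options"]["patches_replace"]["attn2"][("input", 4)]
    print(type(patch).__name__, len(patch.conditions))             # CrossAttentionPatchIID 2

Run as a script, the first call swaps the foreign patch for a fresh CrossAttentionPatchIID and the second call stacks an additional condition onto it, mirroring the behaviour of the new isinstance branch in the diff above.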
