fix attention patching and compatibility with IPAdapter
matt3o committed Feb 13, 2024
1 parent 4872805 commit 106f031
Showing 2 changed files with 124 additions and 721 deletions.
171 changes: 124 additions & 47 deletions InstantID.py
@@ -51,36 +51,41 @@ def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,2
     out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
     return out_img_pil
 
-class CrossAttentionPatchIID:
+class CrossAttentionPatch:
     # forward for patching
-    def __init__(self, weight, instantid, number, cond, uncond, mask=None, sigma_start=0.0, sigma_end=1.0):
+    def __init__(self, weight, ipadapter, number, cond, uncond, weight_type="original", mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False):
         self.weights = [weight]
-        self.instantid = [instantid]
+        self.ipadapters = [ipadapter]
         self.conds = [cond]
         self.unconds = [uncond]
         self.number = number
+        self.weight_type = [weight_type]
         self.masks = [mask]
         self.sigma_start = [sigma_start]
         self.sigma_end = [sigma_end]
+        self.unfold_batch = [unfold_batch]
 
         self.k_key = str(self.number*2+1) + "_to_k_ip"
         self.v_key = str(self.number*2+1) + "_to_v_ip"
 
-    def set_new_condition(self, weight, instantid, number, cond, uncond, mask=None, sigma_start=0.0, sigma_end=1.0):
+    def set_new_condition(self, weight, ipadapter, number, cond, uncond, weight_type="original", mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False):
         self.weights.append(weight)
-        self.instantid.append(instantid)
+        self.ipadapters.append(ipadapter)
         self.conds.append(cond)
         self.unconds.append(uncond)
         self.masks.append(mask)
+        self.weight_type.append(weight_type)
         self.sigma_start.append(sigma_start)
         self.sigma_end.append(sigma_end)
+        self.unfold_batch.append(unfold_batch)
 
     def __call__(self, n, context_attn2, value_attn2, extra_options):
         org_dtype = n.dtype
         cond_or_uncond = extra_options["cond_or_uncond"]
 
-        sigma = extra_options["sigmas"][0] if 'sigmas' in extra_options else None
-        sigma = sigma.item() if sigma else 999999999.9
+        sigma = extra_options["sigmas"][0].item() if 'sigmas' in extra_options else 999999999.9
+
+        # extra options for AnimateDiff
+        ad_params = extra_options['ad_params'] if "ad_params" in extra_options else None
 
         q = n
         k = context_attn2
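
The rewritten sigma guard above gates each stored condition to a window of the sampling schedule. A minimal sketch of the logic, assuming (as the apply node computes them) that sigma_start is the high/early bound and sigma_end the low/late bound:

    def sigma_window_active(sigma, sigma_start, sigma_end):
        # Sigmas decrease as sampling progresses; the patch contributes only
        # while sigma_end <= sigma <= sigma_start, which is the negation of
        # the `continue` test in the rewritten loop below.
        return sigma_end <= sigma <= sigma_start
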
@@ -91,41 +96,112 @@ def __call__(self, n, context_attn2, value_attn2, extra_options):
         out = optimized_attention(q, k, v, extra_options["n_heads"])
         _, _, lh, lw = extra_options["original_shape"]
 
-        for weight, cond, uncond, instantid, mask, sigma_start, sigma_end in zip(self.weights, self.conds, self.unconds, self.instantid, self.masks, self.sigma_start, self.sigma_end):
-            if sigma < sigma_start and sigma > sigma_end:
-                k_cond = instantid.ip_layers.to_kvs[self.k_key](cond).repeat(batch_prompt, 1, 1)
-                k_uncond = instantid.ip_layers.to_kvs[self.k_key](uncond).repeat(batch_prompt, 1, 1)
-                v_cond = instantid.ip_layers.to_kvs[self.v_key](cond).repeat(batch_prompt, 1, 1)
-                v_uncond = instantid.ip_layers.to_kvs[self.v_key](uncond).repeat(batch_prompt, 1, 1)
-
-                iid_k = torch.cat([(k_cond, k_uncond)[i] for i in cond_or_uncond], dim=0)
-                iid_v = torch.cat([(v_cond, v_uncond)[i] for i in cond_or_uncond], dim=0)
-
-                out_iid = optimized_attention(q, iid_k, iid_v, extra_options["n_heads"])
-                out_iid = out_iid * weight
-
-                if mask is not None:
-                    # TODO: needs checking
-                    mask_h = lh / math.sqrt(lh * lw / qs)
-                    mask_h = int(mask_h) + int((qs % int(mask_h)) != 0)
-                    mask_w = qs // mask_h
+        for weight, cond, uncond, ipadapter, mask, weight_type, sigma_start, sigma_end, unfold_batch in zip(self.weights, self.conds, self.unconds, self.ipadapters, self.masks, self.weight_type, self.sigma_start, self.sigma_end, self.unfold_batch):
+            if sigma > sigma_start or sigma < sigma_end:
+                continue
+
+            if unfold_batch and cond.shape[0] > 1:
+                # Check AnimateDiff context window
+                if ad_params is not None and ad_params["sub_idxs"] is not None:
+                    # if images length matches or exceeds full_length get sub_idx images
+                    if cond.shape[0] >= ad_params["full_length"]:
+                        cond = torch.Tensor(cond[ad_params["sub_idxs"]])
+                        uncond = torch.Tensor(uncond[ad_params["sub_idxs"]])
+                    # otherwise, need to do more to get proper sub_idxs masks
+                    else:
+                        # check if images length matches full_length - if not, make it match
+                        if cond.shape[0] < ad_params["full_length"]:
+                            cond = torch.cat((cond, cond[-1:].repeat((ad_params["full_length"]-cond.shape[0], 1, 1))), dim=0)
+                            uncond = torch.cat((uncond, uncond[-1:].repeat((ad_params["full_length"]-uncond.shape[0], 1, 1))), dim=0)
+                        # if we have too many remove the excess (should not happen, but just in case)
+                        if cond.shape[0] > ad_params["full_length"]:
+                            cond = cond[:ad_params["full_length"]]
+                            uncond = uncond[:ad_params["full_length"]]
+                        cond = cond[ad_params["sub_idxs"]]
+                        uncond = uncond[ad_params["sub_idxs"]]
+
+                # if we don't have enough reference images repeat the last one until we reach the right size
+                if cond.shape[0] < batch_prompt:
+                    cond = torch.cat((cond, cond[-1:].repeat((batch_prompt-cond.shape[0], 1, 1))), dim=0)
+                    uncond = torch.cat((uncond, uncond[-1:].repeat((batch_prompt-uncond.shape[0], 1, 1))), dim=0)
+                # if we have too many remove the exceeding
+                elif cond.shape[0] > batch_prompt:
+                    cond = cond[:batch_prompt]
+                    uncond = uncond[:batch_prompt]
+
+                k_cond = ipadapter.ip_layers.to_kvs[self.k_key](cond)
+                k_uncond = ipadapter.ip_layers.to_kvs[self.k_key](uncond)
+                v_cond = ipadapter.ip_layers.to_kvs[self.v_key](cond)
+                v_uncond = ipadapter.ip_layers.to_kvs[self.v_key](uncond)
+            else:
+                k_cond = ipadapter.ip_layers.to_kvs[self.k_key](cond).repeat(batch_prompt, 1, 1)
+                k_uncond = ipadapter.ip_layers.to_kvs[self.k_key](uncond).repeat(batch_prompt, 1, 1)
+                v_cond = ipadapter.ip_layers.to_kvs[self.v_key](cond).repeat(batch_prompt, 1, 1)
+                v_uncond = ipadapter.ip_layers.to_kvs[self.v_key](uncond).repeat(batch_prompt, 1, 1)
+
+            if weight_type.startswith("linear"):
+                ip_k = torch.cat([(k_cond, k_uncond)[i] for i in cond_or_uncond], dim=0) * weight
+                ip_v = torch.cat([(v_cond, v_uncond)[i] for i in cond_or_uncond], dim=0) * weight
+            else:
+                ip_k = torch.cat([(k_cond, k_uncond)[i] for i in cond_or_uncond], dim=0)
+                ip_v = torch.cat([(v_cond, v_uncond)[i] for i in cond_or_uncond], dim=0)
+
+                if weight_type.startswith("channel"):
+                    # code by Lvmin Zhang at Stanford University as also seen on Fooocus IPAdapter implementation
+                    ip_v_mean = torch.mean(ip_v, dim=1, keepdim=True)
+                    ip_v_offset = ip_v - ip_v_mean
+                    _, _, C = ip_k.shape
+                    channel_penalty = float(C) / 1280.0
+                    W = weight * channel_penalty
+                    ip_k = ip_k * W
+                    ip_v = ip_v_offset + ip_v_mean * W
+
+            out_ip = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"])
+            if weight_type.startswith("original"):
+                out_ip = out_ip * weight
+
+            if mask is not None:
+                # TODO: needs checking
+                mask_h = lh / math.sqrt(lh * lw / qs)
+                mask_h = int(mask_h) + int((qs % int(mask_h)) != 0)
+                mask_w = qs // mask_h
+
+                # check if using AnimateDiff and sliding context window
+                if (mask.shape[0] > 1 and ad_params is not None and ad_params["sub_idxs"] is not None):
+                    # if mask length matches or exceeds full_length, just get sub_idx masks, resize, and continue
+                    if mask.shape[0] >= ad_params["full_length"]:
+                        mask_downsample = torch.Tensor(mask[ad_params["sub_idxs"]])
+                        mask_downsample = F.interpolate(mask_downsample.unsqueeze(1), size=(mask_h, mask_w), mode="bicubic").squeeze(1)
+                    # otherwise, need to do more to get proper sub_idxs masks
+                    else:
+                        # resize to needed attention size (to save on memory)
+                        mask_downsample = F.interpolate(mask.unsqueeze(1), size=(mask_h, mask_w), mode="bicubic").squeeze(1)
+                        # check if mask length matches full_length - if not, make it match
+                        if mask_downsample.shape[0] < ad_params["full_length"]:
+                            mask_downsample = torch.cat((mask_downsample, mask_downsample[-1:].repeat((ad_params["full_length"]-mask_downsample.shape[0], 1, 1))), dim=0)
+                        # if we have too many remove the excess (should not happen, but just in case)
+                        if mask_downsample.shape[0] > ad_params["full_length"]:
+                            mask_downsample = mask_downsample[:ad_params["full_length"]]
+                        # now, select sub_idxs masks
+                        mask_downsample = mask_downsample[ad_params["sub_idxs"]]
+                # otherwise, perform usual mask interpolation
+                else:
+                    mask_downsample = F.interpolate(mask.unsqueeze(1), size=(mask_h, mask_w), mode="bicubic").squeeze(1)
 
-                    # if we don't have enough masks repeat the last one until we reach the right size
-                    if mask_downsample.shape[0] < batch_prompt:
-                        mask_downsample = torch.cat((mask_downsample, mask_downsample[-1:, :, :].repeat((batch_prompt-mask_downsample.shape[0], 1, 1))), dim=0)
-                    # if we have too many remove the exceeding
-                    elif mask_downsample.shape[0] > batch_prompt:
-                        mask_downsample = mask_downsample[:batch_prompt, :, :]
-                    # repeat the masks
-                    mask_downsample = mask_downsample.repeat(len(cond_or_uncond), 1, 1)
-                    mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1, 1).repeat(1, 1, out.shape[2])
+                # if we don't have enough masks repeat the last one until we reach the right size
+                if mask_downsample.shape[0] < batch_prompt:
+                    mask_downsample = torch.cat((mask_downsample, mask_downsample[-1:, :, :].repeat((batch_prompt-mask_downsample.shape[0], 1, 1))), dim=0)
+                # if we have too many remove the exceeding
+                elif mask_downsample.shape[0] > batch_prompt:
+                    mask_downsample = mask_downsample[:batch_prompt, :, :]
+
+                # repeat the masks
+                mask_downsample = mask_downsample.repeat(len(cond_or_uncond), 1, 1)
+                mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1, 1).repeat(1, 1, out.shape[2])
 
-                    out_iid = out_iid * mask_downsample
+                out_ip = out_ip * mask_downsample
 
-            out = out + out_iid
+            out = out + out_ip
 
         return out.to(dtype=org_dtype)
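
The mask branch is still flagged "TODO: needs checking": it recovers a (mask_h, mask_w) grid whose product covers the query token count qs, so the attention mask can be bicubically resized to the attention resolution. A worked sketch with hypothetical square-latent numbers:

    import math

    # Hypothetical values: a 128x128 latent attended over qs = 1024 tokens.
    lh, lw, qs = 128, 128, 1024

    mask_h = lh / math.sqrt(lh * lw / qs)                # 128 / sqrt(16) = 32.0
    mask_h = int(mask_h) + int((qs % int(mask_h)) != 0)  # round up on uneven division -> 32
    mask_w = qs // mask_h                                # 1024 // 32 = 32
    assert mask_h * mask_w == qs                         # exact here; non-square shapes may overshoot
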

@@ -195,12 +271,12 @@ def _set_model_patch_replace(model, patch_kwargs, key):
     to = model.model_options["transformer_options"]
     if "patches_replace" not in to:
         to["patches_replace"] = {}
-    if "attn2_iid" not in to["patches_replace"]:
-        to["patches_replace"]["attn2_iid"] = {}
-    if key not in to["patches_replace"]["attn2_iid"] or not isinstance(to["patches_replace"]["attn2_iid"][key], CrossAttentionPatchIID):
-        to["patches_replace"]["attn2_iid"][key] = CrossAttentionPatchIID(**patch_kwargs)
+    if "attn2" not in to["patches_replace"]:
+        to["patches_replace"]["attn2"] = {}
+    if key not in to["patches_replace"]["attn2"]:
+        to["patches_replace"]["attn2"][key] = CrossAttentionPatch(**patch_kwargs)
     else:
-        to["patches_replace"]["attn2_iid"][key].set_new_condition(**patch_kwargs)
+        to["patches_replace"]["attn2"][key].set_new_condition(**patch_kwargs)
 
 class InstantIDModelLoader:
     @classmethod
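
This is the heart of the compatibility fix: InstantID now registers its patch in the shared "attn2" dict (instead of a private "attn2_iid") with an IPAdapter-compatible signature, so a block already patched by IPAdapter is extended via set_new_condition() rather than overwritten. A hypothetical illustration (variable names and the key layout are assumptions, not from this commit):

    # key layout follows ComfyUI's (block_name, block_id, transformer_index) convention
    key = ("input", 4, 0)

    _set_model_patch_replace(model, ipadapter_kwargs, key)  # first caller installs a CrossAttentionPatch
    _set_model_patch_replace(model, instantid_kwargs, key)  # second caller lands in the else branch
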
@@ -381,12 +457,13 @@ def apply_instantid(self, instantid, insightface, image_features, model, positiv
         patch_kwargs = {
             "number": 0,
             "weight": self.weight,
-            "instantid": self.instantid,
+            "ipadapter": self.instantid,
             "cond": image_prompt_embeds,
             "uncond": uncond_image_prompt_embeds,
             "mask": attn_mask,
             "sigma_start": sigma_start,
             "sigma_end": sigma_end,
+            "weight_type": "original",
         }
 
         if not is_sdxl:
@@ -415,14 +492,14 @@ def apply_instantid(self, instantid, insightface, image_features, model, positiv
         pos = []
         for t in positive:
             n = [t[0], t[1].copy()]
-            n[1]['cross_attn_controlnet'] = image_prompt_embeds.cpu()
+            n[1]['cross_attn_controlnet'] = image_prompt_embeds.to(comfy.model_management.intermediate_device())
             pos.append(n)
         #pos[0][1]['cross_attn_controlnet'] = image_prompt_embeds.cpu()
 
         neg = []
         for t in negative:
             n = [t[0], t[1].copy()]
-            n[1]['cross_attn_controlnet'] = uncond_image_prompt_embeds.cpu()
+            n[1]['cross_attn_controlnet'] = uncond_image_prompt_embeds.to(comfy.model_management.intermediate_device())
             neg.append(n)
         #neg[0][1]['cross_attn_controlnet'] = uncond_image_prompt_embeds.cpu()