align attention with ipadapter
matt3o committed Apr 13, 2024
1 parent 23b2968 · commit 8b7932a
Showing 1 changed file with 15 additions and 28 deletions.
CrossAttentionPatch.py: 43 changes (15 additions & 28 deletions)
```diff
@@ -82,53 +82,40 @@ def __call__(self, q, k, v, extra_options):
                 cond = cond_alt[t_idx]
                 del cond_alt
 
-            #if isinstance(weight, torch.Tensor):
-            #    weight = tensor_to_size(weight, batch_prompt)
-            #    weight = weight.repeat(len(cond_or_uncond), 1, 1)
-            #elif weight == 0:
-            #    continue
-
-            if unfold_batch and cond.shape[0] > 1:
+            if unfold_batch:
                 # Check AnimateDiff context window
                 if ad_params is not None and ad_params["sub_idxs"] is not None:
-                    # if image length matches or exceeds full_length get sub_idx images
-                    if cond.shape[0] >= ad_params["full_length"]:
-                        if isinstance(weight, torch.Tensor):
-                            weight = torch.Tensor(weight["sub_idxs"])
-                            weight = weight.repeat(len(cond_or_uncond), 1, 1) # repeat for cond and uncond
-                            if torch.all(weight == 0):
-                                continue
-                        elif weight == 0:
+                    if isinstance(weight, torch.Tensor):
+                        weight = tensor_to_size(weight, ad_params["full_length"])
+                        weight = torch.Tensor(weight[ad_params["sub_idxs"]])
+                        if torch.all(weight == 0):
+                            continue
+                        weight = weight.repeat(len(cond_or_uncond), 1, 1) # repeat for cond and uncond
+                    elif weight == 0:
                             continue
+
+                    # if image length matches or exceeds full_length get sub_idx images
+                    if cond.shape[0] >= ad_params["full_length"]:
                         cond = torch.Tensor(cond[ad_params["sub_idxs"]])
                         uncond = torch.Tensor(uncond[ad_params["sub_idxs"]])
                     # otherwise get sub_idxs images
                     else:
-                        if isinstance(weight, torch.Tensor):
-                            weight = tensor_to_size(weight, ad_params["full_length"])
-                            weight = weight[ad_params["sub_idxs"]]
-                            weight = weight.repeat(len(cond_or_uncond), 1, 1) # repeat for cond and uncond
-                            if torch.all(weight == 0):
-                                continue
-                        elif weight == 0:
-                            continue
                         cond = tensor_to_size(cond, ad_params["full_length"])
                         uncond = tensor_to_size(uncond, ad_params["full_length"])
                         cond = cond[ad_params["sub_idxs"]]
                         uncond = uncond[ad_params["sub_idxs"]]
-            else:
-                cond = tensor_to_size(cond, batch_prompt)
-                uncond = tensor_to_size(uncond, batch_prompt)
 
                 if isinstance(weight, torch.Tensor):
                     weight = tensor_to_size(weight, batch_prompt)
-                    weight = weight.repeat(len(cond_or_uncond), 1, 1)
                     if torch.all(weight == 0):
                         continue
+                    weight = weight.repeat(len(cond_or_uncond), 1, 1) # repeat for cond and uncond
                 elif weight == 0:
                     continue
 
+                cond = tensor_to_size(cond, batch_prompt)
+                uncond = tensor_to_size(uncond, batch_prompt)
+
             k_cond = ipadapter.ip_layers.to_kvs[self.k_key](cond)
             k_uncond = ipadapter.ip_layers.to_kvs[self.k_key](uncond)
             v_cond = ipadapter.ip_layers.to_kvs[self.v_key](cond)
@@ -137,9 +124,9 @@ def __call__(self, q, k, v, extra_options):
             # TODO: should we always convert the weights to a tensor?
             if isinstance(weight, torch.Tensor):
                 weight = tensor_to_size(weight, batch_prompt)
-                weight = weight.repeat(len(cond_or_uncond), 1, 1)
                 if torch.all(weight == 0):
                     continue
+                weight = weight.repeat(len(cond_or_uncond), 1, 1) # repeat for cond and uncond
             elif weight == 0:
                 continue
 
```
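For context, `tensor_to_size()` is a helper defined elsewhere in the repository and is not shown in this diff. A minimal sketch of the semantics the patched code appears to rely on, not the repository's exact implementation: stretch or trim a tensor along dim 0 to a target batch size.

```python
import torch

# Hedged sketch of the tensor_to_size() helper used throughout the diff.
# Assumption: it pads a tensor along dim 0 by repeating the last entry,
# or truncates it, so the result has exactly dest_size rows.
def tensor_to_size(source: torch.Tensor, dest_size: int) -> torch.Tensor:
    source_size = source.shape[0]
    if source_size < dest_size:
        pad = [dest_size - source_size] + [1] * (source.dim() - 1)
        source = torch.cat((source, source[-1:].repeat(pad)), dim=0)
    elif source_size > dest_size:
        source = source[:dest_size]
    return source
```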

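After this change both weight paths follow the same order: resize the per-image weights, slice out the AnimateDiff window where applicable, skip the layer early when every weight in play is zero, and only then repeat the weights for the cond and uncond passes. A hypothetical walk-through with made-up numbers, reusing the `tensor_to_size` sketch above:

```python
import torch

# Hypothetical walk-through of the new weight path for an AnimateDiff
# context window. full_length is the total frame count, sub_idxs the
# frames inside the current context window; the values are illustrative.
full_length = 16
sub_idxs = [4, 5, 6, 7]
cond_or_uncond = [0, 1]                        # one entry each for cond and uncond

weight = torch.ones(8, 1, 1)                   # fewer weights than frames
weight = tensor_to_size(weight, full_length)   # pad to one weight per frame
weight = weight[sub_idxs]                      # slice the current window -> (4, 1, 1)

# the all-zero check now runs on the small sliced tensor, before the repeat
if not torch.all(weight == 0):
    weight = weight.repeat(len(cond_or_uncond), 1, 1)  # (8, 1, 1): cond + uncond
```

Checking for all-zero weights before the repeat avoids allocating the doubled tensor for windows that contribute nothing, and presumably mirrors the equivalent logic in IPAdapter.py that the commit message "align attention with ipadapter" refers to.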