From 8b7932a66bf43b5e2499e2ba77df5c373af006da Mon Sep 17 00:00:00 2001 From: matt3o Date: Sat, 13 Apr 2024 11:44:53 +0200 Subject: [PATCH] align attention with ipadapter --- CrossAttentionPatch.py | 43 +++++++++++++++--------------------------- 1 file changed, 15 insertions(+), 28 deletions(-) diff --git a/CrossAttentionPatch.py b/CrossAttentionPatch.py index e0dd5c5..ac2d0fe 100644 --- a/CrossAttentionPatch.py +++ b/CrossAttentionPatch.py @@ -82,53 +82,40 @@ def __call__(self, q, k, v, extra_options): cond = cond_alt[t_idx] del cond_alt - #if isinstance(weight, torch.Tensor): - # weight = tensor_to_size(weight, batch_prompt) - # weight = weight.repeat(len(cond_or_uncond), 1, 1) - #elif weight == 0: - # continue - - if unfold_batch and cond.shape[0] > 1: + if unfold_batch: # Check AnimateDiff context window if ad_params is not None and ad_params["sub_idxs"] is not None: - # if image length matches or exceeds full_length get sub_idx images - if cond.shape[0] >= ad_params["full_length"]: - if isinstance(weight, torch.Tensor): - weight = torch.Tensor(weight["sub_idxs"]) - weight = weight.repeat(len(cond_or_uncond), 1, 1) # repeat for cond and uncond - if torch.all(weight == 0): - continue - elif weight == 0: + if isinstance(weight, torch.Tensor): + weight = tensor_to_size(weight, ad_params["full_length"]) + weight = torch.Tensor(weight[ad_params["sub_idxs"]]) + if torch.all(weight == 0): continue + weight = weight.repeat(len(cond_or_uncond), 1, 1) # repeat for cond and uncond + elif weight == 0: + continue + # if image length matches or exceeds full_length get sub_idx images + if cond.shape[0] >= ad_params["full_length"]: cond = torch.Tensor(cond[ad_params["sub_idxs"]]) uncond = torch.Tensor(uncond[ad_params["sub_idxs"]]) # otherwise get sub_idxs images else: - if isinstance(weight, torch.Tensor): - weight = tensor_to_size(weight, ad_params["full_length"]) - weight = weight[ad_params["sub_idxs"]] - weight = weight.repeat(len(cond_or_uncond), 1, 1) # repeat for cond and uncond - if torch.all(weight == 0): - continue - elif weight == 0: - continue cond = tensor_to_size(cond, ad_params["full_length"]) uncond = tensor_to_size(uncond, ad_params["full_length"]) cond = cond[ad_params["sub_idxs"]] uncond = uncond[ad_params["sub_idxs"]] else: - cond = tensor_to_size(cond, batch_prompt) - uncond = tensor_to_size(uncond, batch_prompt) - if isinstance(weight, torch.Tensor): weight = tensor_to_size(weight, batch_prompt) - weight = weight.repeat(len(cond_or_uncond), 1, 1) if torch.all(weight == 0): continue + weight = weight.repeat(len(cond_or_uncond), 1, 1) # repeat for cond and uncond elif weight == 0: continue + cond = tensor_to_size(cond, batch_prompt) + uncond = tensor_to_size(uncond, batch_prompt) + k_cond = ipadapter.ip_layers.to_kvs[self.k_key](cond) k_uncond = ipadapter.ip_layers.to_kvs[self.k_key](uncond) v_cond = ipadapter.ip_layers.to_kvs[self.v_key](cond) @@ -137,9 +124,9 @@ def __call__(self, q, k, v, extra_options): # TODO: should we always convert the weights to a tensor? if isinstance(weight, torch.Tensor): weight = tensor_to_size(weight, batch_prompt) - weight = weight.repeat(len(cond_or_uncond), 1, 1) if torch.all(weight == 0): continue + weight = weight.repeat(len(cond_or_uncond), 1, 1) # repeat for cond and uncond elif weight == 0: continue