From 59d60083130b5248378b899ea8dd5a0f2cf92d0c Mon Sep 17 00:00:00 2001
From: matt3o
Date: Tue, 9 Apr 2024 14:04:04 +0200
Subject: [PATCH] align attention patch with ipadapter

---
 CrossAttentionPatch.py | 8 +-------
 1 file changed, 1 insertion(+), 7 deletions(-)

diff --git a/CrossAttentionPatch.py b/CrossAttentionPatch.py
index e3a5ee9..06f153e 100644
--- a/CrossAttentionPatch.py
+++ b/CrossAttentionPatch.py
@@ -19,7 +19,7 @@ def __init__(self, ipadapter=None, number=0, weight=1.0, cond=None, cond_alt=Non
         self.unfold_batch = [unfold_batch]
         self.embeds_scaling = [embeds_scaling]
         self.number = number
-        self.layers = 11 if '101_to_k_ip' in ipadapter.ip_layers.to_kvs else 15 # TODO: check if this is a valid condition to detect all models
+        self.layers = 11 if '101_to_k_ip' in ipadapter.ip_layers.to_kvs else 16 # TODO: check if this is a valid condition to detect all models
         self.k_key = str(self.number*2+1) + "_to_k_ip"
         self.v_key = str(self.number*2+1) + "_to_v_ip"
 
@@ -72,12 +72,6 @@ def __call__(self, q, k, v, extra_options):
                 weight = weight * 0.2
             elif weight_type == 'strong middle' and (block_type == 'input' or block_type == 'output'):
                 weight = weight * 0.2
-            elif weight_type.startswith('style transfer'):
-                if t_idx != 6:
-                    continue
-            elif weight_type.startswith('composition'):
-                if t_idx != 3:
-                    continue
             elif isinstance(weight, dict):
                 if t_idx not in weight:
                     continue