Skip to content

Commit

Permalink
fix compatibility with IPAdapter
Browse files (browse the repository at this point in the history)
  • Loading branch information
matt3o committed Apr 4, 2024
1 parent ac20ccc commit 5044599
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 22 deletions.
62 changes: 41 additions & 21 deletions CrossAttentionPatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@

class CrossAttentionPatch:
# forward for patching
def __init__(self, ipadapter=None, number=0, weight=1.0, cond=None, uncond=None, weight_type="linear", mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False, embeds_scaling='V only'):
def __init__(self, ipadapter=None, number=0, weight=1.0, cond=None, cond_alt=None, uncond=None, weight_type="linear", mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False, embeds_scaling='V only'):
self.weights = [weight]
self.ipadapters = [ipadapter]
self.conds = [cond]
self.conds_alt = [cond_alt]
self.unconds = [uncond]
self.weight_types = [weight_type]
self.masks = [mask]
Expand All @@ -18,15 +19,16 @@ def __init__(self, ipadapter=None, number=0, weight=1.0, cond=None, uncond=None,
self.unfold_batch = [unfold_batch]
self.embeds_scaling = [embeds_scaling]
self.number = number
self.layers = 10 if '101_to_k_ip' in ipadapter.ip_layers.to_kvs else 15
self.layers = 11 if '101_to_k_ip' in ipadapter.ip_layers.to_kvs else 15 # TODO: check if this is a valid condition to detect all models

self.k_key = str(self.number*2+1) + "_to_k_ip"
self.v_key = str(self.number*2+1) + "_to_v_ip"

def set_new_condition(self, ipadapter=None, number=0, weight=1.0, cond=None, uncond=None, weight_type="linear", mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False, embeds_scaling='V only'):
def set_new_condition(self, ipadapter=None, number=0, weight=1.0, cond=None, cond_alt=None, uncond=None, weight_type="linear", mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False, embeds_scaling='V only'):
self.weights.append(weight)
self.ipadapters.append(ipadapter)
self.conds.append(cond)
self.conds_alt.append(cond_alt)
self.unconds.append(uncond)
self.weight_types.append(weight_type)
self.masks.append(mask)
Expand All @@ -52,8 +54,43 @@ def __call__(self, q, k, v, extra_options):
out = optimized_attention(q, k, v, extra_options["n_heads"])
_, _, oh, ow = extra_options["original_shape"]

for weight, cond, uncond, ipadapter, mask, weight_type, sigma_start, sigma_end, unfold_batch, embeds_scaling in zip(self.weights, self.conds, self.unconds, self.ipadapters, self.masks, self.weight_types, self.sigma_starts, self.sigma_ends, self.unfold_batch, self.embeds_scaling):
for weight, cond, cond_alt, uncond, ipadapter, mask, weight_type, sigma_start, sigma_end, unfold_batch, embeds_scaling in zip(self.weights, self.conds, self.conds_alt, self.unconds, self.ipadapters, self.masks, self.weight_types, self.sigma_starts, self.sigma_ends, self.unfold_batch, self.embeds_scaling):
if sigma <= sigma_start and sigma >= sigma_end:
if weight_type == 'ease in':
weight = weight * (0.05 + 0.95 * (1 - t_idx / self.layers))
elif weight_type == 'ease out':
weight = weight * (0.05 + 0.95 * (t_idx / self.layers))
elif weight_type == 'ease in-out':
weight = weight * (0.05 + 0.95 * (1 - abs(t_idx - (self.layers/2)) / (self.layers/2)))
elif weight_type == 'reverse in-out':
weight = weight * (0.05 + 0.95 * (abs(t_idx - (self.layers/2)) / (self.layers/2)))
elif weight_type == 'weak input' and block_type == 'input':
weight = weight * 0.2
elif weight_type == 'weak middle' and block_type == 'middle':
weight = weight * 0.2
elif weight_type == 'weak output' and block_type == 'output':
weight = weight * 0.2
elif weight_type == 'strong middle' and (block_type == 'input' or block_type == 'output'):
weight = weight * 0.2
elif weight_type.startswith('style transfer'):
if t_idx != 6:
continue
elif weight_type.startswith('composition'):
if t_idx != 3:
continue
elif isinstance(weight, dict):
if t_idx not in weight:
continue

weight = weight[t_idx]

if t_idx in cond_alt:
cond = cond_alt[t_idx]
del cond_alt

if weight == 0:
continue

if unfold_batch and cond.shape[0] > 1:
# Check AnimateDiff context window
if ad_params is not None and ad_params["sub_idxs"] is not None:
Expand Down Expand Up @@ -81,23 +118,6 @@ def __call__(self, q, k, v, extra_options):
v_cond = ipadapter.ip_layers.to_kvs[self.v_key](cond).repeat(batch_prompt, 1, 1)
v_uncond = ipadapter.ip_layers.to_kvs[self.v_key](uncond).repeat(batch_prompt, 1, 1)

if weight_type == 'ease in':
weight = weight * (0.05 + 0.95 * (1 - t_idx / self.layers))
elif weight_type == 'ease out':
weight = weight * (0.05 + 0.95 * (t_idx / self.layers))
elif weight_type == 'ease in-out':
weight = weight * (0.05 + 0.95 * (1 - abs(t_idx - (self.layers/2)) / (self.layers/2)))
elif weight_type == 'reverse in-out':
weight = weight * (0.05 + 0.95 * (abs(t_idx - (self.layers/2)) / (self.layers/2)))
elif weight_type == 'weak input' and block_type == 'input':
weight = weight * 0.2
elif weight_type == 'weak middle' and block_type == 'middle':
weight = weight * 0.2
elif weight_type == 'weak output' and block_type == 'output':
weight = weight * 0.2
elif weight_type == 'strong middle' and (block_type == 'input' or block_type == 'output'):
weight = weight * 0.2

ip_k = torch.cat([(k_cond, k_uncond)[i] for i in cond_or_uncond], dim=0)
ip_v = torch.cat([(v_cond, v_uncond)[i] for i in cond_or_uncond], dim=0)

Expand Down
2 changes: 1 addition & 1 deletion InstantID.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def extractFeatures(insightface, image, extract_kps=False):
insightface.det_model.input_size = size # TODO: hacky but seems to be working
face = insightface.get(face_img[i])
if face:
face = sorted(face, key=lambda x:(x['bbox'][2]-x['bbox'][0])*x['bbox'][3]-x['bbox'][1])[-1]
face = sorted(face, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]

if extract_kps:
out.append(draw_kps(face_img[i], face['kps']))
Expand Down

0 comments on commit 5044599

Please sign in to comment.