diff --git a/InstantID.py b/InstantID.py index 34133b5..fa25c37 100644 --- a/InstantID.py +++ b/InstantID.py @@ -123,7 +123,7 @@ def __call__(self, n, context_attn2, value_attn2, extra_options): mask_downsample = mask_downsample.repeat(len(cond_or_uncond), 1, 1) mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1, 1).repeat(1, 1, out.shape[2]) - out_ip = out_ip * mask_downsample + out_iid = out_iid * mask_downsample out = out + out_iid