add combine embeds option
matt3o committed Mar 29, 2024
1 parent e6477fe commit ac20ccc
Showing 1 changed file with 14 additions and 11 deletions: InstantID.py
@@ -6,10 +6,9 @@
import math
import cv2
import PIL.Image
from comfy.ldm.modules.attention import optimized_attention
from .resampler import Resampler
from .CrossAttentionPatch import CrossAttentionPatch
from .utils import tensor_to_size, tensor_to_image, image_to_tensor
from .utils import tensor_to_image

from insightface.app import FaceAnalysis
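For orientation, FaceAnalysis is insightface's detector/embedder bundle that the nodes below rely on. A minimal usage sketch, assuming the stock insightface API and the antelopev2 model pack commonly used with InstantID (the image path is hypothetical; this is not code from the commit):

import cv2
from insightface.app import FaceAnalysis

app = FaceAnalysis(name="antelopev2", providers=["CPUExecutionProvider"])
app.prepare(ctx_id=0, det_size=(640, 640))
faces = app.get(cv2.imread("face.jpg"))  # each detected face carries .embedding and .kps keypoints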

@@ -96,12 +95,12 @@ def get_image_embeds(self, clip_embed, clip_embed_zeroed):
class ImageProjModel(torch.nn.Module):
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
super().__init__()

self.cross_attention_dim = cross_attention_dim
self.clip_extra_context_tokens = clip_extra_context_tokens
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
self.norm = torch.nn.LayerNorm(cross_attention_dim)

def forward(self, image_embeds):
embeds = image_embeds
clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
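A quick shape sketch of the projection above, using the default arguments shown (my own illustration, not part of the diff); the hunk cuts off before the LayerNorm and return, which presumably keep the same shape:

import torch
proj = torch.nn.Linear(1024, 4 * 1024)                    # clip_embeddings_dim -> tokens * cross_attention_dim
tokens = proj(torch.randn(2, 1024)).reshape(-1, 4, 1024)  # two image embeds -> (2, 4, 1024) context tokens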
@@ -151,7 +150,7 @@ def load_model(self, instantid_file):
elif key.startswith("ip_adapter."):
st_model["ip_adapter"][key.replace("ip_adapter.", "")] = model[key]
model = st_model

return (model,)
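The loop above splits a flat checkpoint state dict into the two sub-dicts the adapter expects. A self-contained illustration with hypothetical keys (the matching "image_proj." branch is cut off by the hunk and assumed here to mirror the elif shown):

model = {"image_proj.proj.weight": 1, "ip_adapter.1.to_k_ip.weight": 2}  # hypothetical keys
st_model = {"image_proj": {}, "ip_adapter": {}}
for key in model.keys():
    if key.startswith("image_proj."):
        st_model["image_proj"][key.replace("image_proj.", "")] = model[key]
    elif key.startswith("ip_adapter."):
        st_model["ip_adapter"][key.replace("ip_adapter.", "")] = model[key]
# st_model == {"image_proj": {"proj.weight": 1}, "ip_adapter": {"1.to_k_ip.weight": 2}}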

def extractFeatures(insightface, image, extract_kps=False):
@@ -234,7 +233,7 @@ def add_noise(image, factor):
mask = (torch.rand_like(image) < factor).float()
noise = torch.rand_like(image)
noise = torch.zeros_like(image) * (1-mask) + noise * mask

return factor*noise
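Reading the function above as a sketch: about a fraction equal to factor of the elements are replaced with uniform noise scaled by factor, the rest become zero, and the original pixel values themselves are discarded (only the shape and dtype of image are reused). For example:

import torch
img = torch.ones(1, 8, 8, 3)
noised = add_noise(img, 0.35)  # roughly 35% of elements fall in [0, 0.35), the others are exactly 0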

class ApplyInstantID:
@@ -264,7 +263,7 @@ def INPUT_TYPES(s):
FUNCTION = "apply_instantid"
CATEGORY = "InstantID"

def apply_instantid(self, instantid, insightface, control_net, image, model, positive, negative, start_at, end_at, weight=.8, ip_weight=None, cn_strength=None, noise=0.35, image_kps=None, mask=None):
def apply_instantid(self, instantid, insightface, control_net, image, model, positive, negative, start_at, end_at, weight=.8, ip_weight=None, cn_strength=None, noise=0.35, image_kps=None, mask=None, combine_embeds='average'):
self.dtype = torch.float16 if comfy.model_management.should_use_fp16() else torch.float32
self.device = comfy.model_management.get_torch_device()

@@ -290,7 +289,10 @@ def apply_instantid(self, instantid, insightface, control_net, image, model, pos
clip_embed = face_embed
# InstantID works better with averaged embeds (TODO: needs testing)
if clip_embed.shape[0] > 1:
clip_embed = torch.mean(clip_embed, dim=0).unsqueeze(0)
if combine_embeds == 'average':
clip_embed = torch.mean(clip_embed, dim=0).unsqueeze(0)
elif combine_embeds == 'norm average':
clip_embed = torch.mean(clip_embed / torch.norm(clip_embed, dim=0, keepdim=True), dim=0).unsqueeze(0)

if noise > 0:
seed = int(torch.sum(clip_embed).item()) % 1000000007
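The branch above is the heart of this commit: when several reference faces are supplied, 'average' takes a plain mean of the stacked embeddings, while 'norm average' normalizes them before averaging (the diff normalizes along dim=0, mirrored verbatim below). The 'concat' choice added to INPUT_TYPES has no explicit branch in this hunk, so presumably the stacked embeddings pass through unchanged there. A standalone sketch of the selection logic (a hypothetical helper for illustration, not code from the repository; 512-d embeddings are only an example size):

import torch

def combine_face_embeds(clip_embed, mode='average'):
    # Hypothetical helper mirroring the branch above; 'concat' is assumed to
    # leave the stacked embeddings untouched, since this hunk adds no branch for it.
    if clip_embed.shape[0] > 1:
        if mode == 'average':
            clip_embed = torch.mean(clip_embed, dim=0).unsqueeze(0)
        elif mode == 'norm average':
            clip_embed = torch.mean(clip_embed / torch.norm(clip_embed, dim=0, keepdim=True), dim=0).unsqueeze(0)
    return clip_embed

print(combine_face_embeds(torch.randn(3, 512), 'average').shape)       # torch.Size([1, 512])
print(combine_face_embeds(torch.randn(3, 512), 'norm average').shape)  # torch.Size([1, 512])
print(combine_face_embeds(torch.randn(3, 512), 'concat').shape)        # torch.Size([3, 512])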
@@ -413,6 +415,7 @@ def INPUT_TYPES(s):
"start_at": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, }),
"end_at": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001, }),
"noise": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.1, }),
"combine_embeds": (['average', 'norm average', 'concat'], {"default": 'average'}),
},
"optional": {
"image_kps": ("IMAGE",),
@@ -443,7 +446,7 @@ def INPUT_TYPES(s):
FUNCTION = "patch_attention"
CATEGORY = "InstantID"

def patch_attention(self, instantid, insightface, image, model, weight, start_at, end_at, noise=0.0, mask=None):
self.dtype = torch.float16 if comfy.model_management.should_use_fp16() else torch.float32
self.device = comfy.model_management.get_torch_device()

@@ -530,7 +533,7 @@ def patch_attention(self, instantid, insightface, image, model, weight, start_at
for index in range(10):
_set_model_patch_replace(work_model, patch_kwargs, ("middle", 0, index))
patch_kwargs["number"] += 1

return(work_model, { "cond": image_prompt_embeds, "uncond": uncond_image_prompt_embeds }, )
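For reference, the dict returned above is the FACE_EMBEDS output that the ApplyInstantIDControlNet node below takes as face_embeds; unpacking it would presumably look like this (an assumption about the consuming code, which this diff does not show):

image_prompt_embeds = face_embeds["cond"]
uncond_image_prompt_embeds = face_embeds["uncond"]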

class ApplyInstantIDControlNet:
@@ -565,7 +568,7 @@ def apply_controlnet(self, face_embeds, control_net, image_kps, positive, negati

if mask is not None:
mask = mask.to(self.device)

if mask is not None and len(mask.shape) < 3:
mask = mask.unsqueeze(0)

