Skip to content

Commit

Permalink
Merge pull request cubiq#185 from JettHu/main
Browse files Browse the repository at this point in the history
improve instantID load time
  • Loading branch information
cubiq authored Jul 10, 2024
2 parents d8c70a0 + 63fa54a commit 2c1a6b2
Showing 1 changed file with 11 additions and 31 deletions.
42 changes: 11 additions & 31 deletions InstantID.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,15 @@ def load_model(self, instantid_file):
st_model["ip_adapter"][key.replace("ip_adapter.", "")] = model[key]
model = st_model

return (model,)
instant_id_model = InstantID(
model,
cross_attention_dim=1280,
output_cross_attention_dim=model["ip_adapter"]["1.to_k_ip.weight"].shape[1],
clip_embeddings_dim=512,
clip_extra_context_tokens=16,
)

return (instant_id_model,)

def extractFeatures(insightface, image, extract_kps=False):
face_img = tensor_to_image(image)
Expand Down Expand Up @@ -277,11 +285,6 @@ def apply_instantid(self, instantid, insightface, control_net, image, model, pos
ip_weight = weight if ip_weight is None else ip_weight
cn_strength = weight if cn_strength is None else cn_strength

output_cross_attention_dim = instantid["ip_adapter"]["1.to_k_ip.weight"].shape[1]
is_sdxl = output_cross_attention_dim == 2048
cross_attention_dim = 1280
clip_extra_context_tokens = 16

face_embed = extractFeatures(insightface, image)
if face_embed is None:
raise Exception('Reference Image: No face detected.')
Expand Down Expand Up @@ -309,17 +312,8 @@ def apply_instantid(self, instantid, insightface, control_net, image, model, pos
else:
clip_embed_zeroed = torch.zeros_like(clip_embed)

clip_embeddings_dim = face_embed.shape[-1]

# 1: patch the attention
self.instantid = InstantID(
instantid,
cross_attention_dim=cross_attention_dim,
output_cross_attention_dim=output_cross_attention_dim,
clip_embeddings_dim=clip_embeddings_dim,
clip_extra_context_tokens=clip_extra_context_tokens,
)

self.instantid = instantid
self.instantid.to(self.device, dtype=self.dtype)

image_prompt_embeds, uncond_image_prompt_embeds = self.instantid.get_image_embeds(clip_embed.to(self.device, dtype=self.dtype), clip_embed_zeroed.to(self.device, dtype=self.dtype))
Expand Down Expand Up @@ -451,11 +445,6 @@ def patch_attention(self, instantid, insightface, image, model, weight, start_at
self.dtype = torch.float16 if comfy.model_management.should_use_fp16() else torch.float32
self.device = comfy.model_management.get_torch_device()

output_cross_attention_dim = instantid["ip_adapter"]["1.to_k_ip.weight"].shape[1]
is_sdxl = output_cross_attention_dim == 2048
cross_attention_dim = 1280
clip_extra_context_tokens = 16

face_embed = extractFeatures(insightface, image)
if face_embed is None:
raise Exception('Reference Image: No face detected.')
Expand All @@ -472,17 +461,8 @@ def patch_attention(self, instantid, insightface, image, model, weight, start_at
else:
clip_embed_zeroed = torch.zeros_like(clip_embed)

clip_embeddings_dim = face_embed.shape[-1]

# 1: patch the attention
self.instantid = InstantID(
instantid,
cross_attention_dim=cross_attention_dim,
output_cross_attention_dim=output_cross_attention_dim,
clip_embeddings_dim=clip_embeddings_dim,
clip_extra_context_tokens=clip_extra_context_tokens,
)

self.instantid = instantid
self.instantid.to(self.device, dtype=self.dtype)

image_prompt_embeds, uncond_image_prompt_embeds = self.instantid.get_image_embeds(clip_embed.to(self.device, dtype=self.dtype), clip_embed_zeroed.to(self.device, dtype=self.dtype))
Expand Down

0 comments on commit 2c1a6b2

Please sign in to comment.