Commit

testing
matt3o committed Jan 27, 2024
1 parent 60295a2 commit 1ec20b6
Showing 5 changed files with 1,162 additions and 1 deletion.
355 changes: 355 additions & 0 deletions InstantID.py
@@ -0,0 +1,355 @@
import torch
import os
import comfy.utils
import comfy.model_management # used by ApplyInstantID below
import folder_paths
import numpy as np
import math
import cv2
import PIL.Image
from comfy.ldm.modules.attention import optimized_attention
from .resampler import Resampler

from insightface.app import FaceAnalysis

import torchvision.transforms.v2 as T

MODELS_DIR = os.path.join(folder_paths.models_dir, "instantid")
if "instantid" not in folder_paths.folder_names_and_paths:
    current_paths = [MODELS_DIR]
else:
    current_paths, _ = folder_paths.folder_names_and_paths["instantid"]
folder_paths.folder_names_and_paths["instantid"] = (current_paths, folder_paths.supported_pt_extensions)

INSIGHTFACE_DIR = os.path.join(folder_paths.models_dir, "insightface")

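# Renders the 5 InsightFace keypoints (eyes, nose, mouth corners) as limbs and
# dots on a black canvas, mirroring the keypoint image the reference InstantID
# pipeline feeds to its ControlNet.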
def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]):
    stickwidth = 4
    limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
    kps = np.array(kps)

    h, w, _ = image_pil.shape
    out_img = np.zeros([h, w, 3])

    for i in range(len(limbSeq)):
        index = limbSeq[i]
        color = color_list[index[0]]

        x = kps[index][:, 0]
        y = kps[index][:, 1]
        length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
        angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
        polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
        out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
    out_img = (out_img * 0.6).astype(np.uint8)

    for idx_kp, kp in enumerate(kps):
        color = color_list[idx_kp]
        x, y = kp
        out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)

    out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
    return out_img_pil

def set_model_patch_replace(model, patch_kwargs, key):
    to = model.model_options["transformer_options"]
    if "patches_replace" not in to:
        to["patches_replace"] = {}
    if "attn2" not in to["patches_replace"]:
        to["patches_replace"]["attn2"] = {}
    if key not in to["patches_replace"]["attn2"]:
        patch = CrossAttentionPatch(**patch_kwargs)
        to["patches_replace"]["attn2"][key] = patch
    else:
        to["patches_replace"]["attn2"][key].set_new_condition(**patch_kwargs)


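# Replaces the attn2 (cross-attention) forward of patched layers: it computes
# the regular text cross-attention, then adds a second attention pass over the
# InstantID K/V projections of the face embeds, scaled by `weight`. Multiple
# conditions can be stacked on one patch via set_new_condition().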
class CrossAttentionPatch:
    # forward for patching
    def __init__(self, weight, instantid, number, cond, uncond, mask=None, sigma_start=0.0, sigma_end=1.0):
        self.weights = [weight]
        self.instantid = [instantid]
        self.conds = [cond]
        self.unconds = [uncond]
        self.number = number
        self.masks = [mask]
        self.sigma_start = [sigma_start]
        self.sigma_end = [sigma_end]

        self.k_key = str(self.number*2+1) + "_to_k_ip"
        self.v_key = str(self.number*2+1) + "_to_v_ip"

    def set_new_condition(self, weight, instantid, number, cond, uncond, mask=None, sigma_start=0.0, sigma_end=1.0):
        self.weights.append(weight)
        self.instantid.append(instantid)
        self.conds.append(cond)
        self.unconds.append(uncond)
        self.masks.append(mask)
        self.sigma_start.append(sigma_start)
        self.sigma_end.append(sigma_end)

    def __call__(self, n, context_attn2, value_attn2, extra_options):
        org_dtype = n.dtype
        cond_or_uncond = extra_options["cond_or_uncond"]
        sigma = extra_options["sigmas"][0].item() if 'sigmas' in extra_options else 999999999.9

        q = n
        k = context_attn2
        v = value_attn2
        b = q.shape[0]
        qs = q.shape[1]
        batch_prompt = b // len(cond_or_uncond)
        out = optimized_attention(q, k, v, extra_options["n_heads"])
        _, _, lh, lw = extra_options["original_shape"]

        for weight, cond, uncond, instantid, mask, sigma_start, sigma_end in zip(self.weights, self.conds, self.unconds, self.instantid, self.masks, self.sigma_start, self.sigma_end):
            #if sigma > sigma_start or sigma < sigma_end:
            #    continue

            # both branches are repeated batch_prompt times so the torch.cat
            # below matches q's batch size (repeating the cond branch b times,
            # as originally written, breaks when cond and uncond are batched
            # together)
            k_cond = instantid.ip_layers.to_kvs[self.k_key](cond).repeat(batch_prompt, 1, 1)
            k_uncond = instantid.ip_layers.to_kvs[self.k_key](uncond).repeat(batch_prompt, 1, 1)
            v_cond = instantid.ip_layers.to_kvs[self.v_key](cond).repeat(batch_prompt, 1, 1)
            v_uncond = instantid.ip_layers.to_kvs[self.v_key](uncond).repeat(batch_prompt, 1, 1)

            ip_k = torch.cat([(k_cond, k_uncond)[i] for i in cond_or_uncond], dim=0)
            ip_v = torch.cat([(v_cond, v_uncond)[i] for i in cond_or_uncond], dim=0)

            out_iid = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"])
            out_iid = out_iid * weight

            out = out + out_iid

        return out.to(dtype=org_dtype)


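# Wraps the InstantID checkpoint: a Resampler that projects the face embedding
# into `clip_extra_context_tokens` context tokens, plus the per-layer K/V
# projections ("ip_adapter" weights) exposed through To_KV.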
class InstantID(torch.nn.Module):
    def __init__(self, instantid_model, cross_attention_dim=1024, output_cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.clip_embeddings_dim = clip_embeddings_dim
        self.cross_attention_dim = cross_attention_dim
        self.output_cross_attention_dim = output_cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens

        self.image_proj_model = self.init_proj()

        self.image_proj_model.load_state_dict(instantid_model["image_proj"])
        self.ip_layers = To_KV(instantid_model["ip_adapter"])

    def init_proj(self):
        image_proj_model = Resampler(
            dim=self.cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=20,
            num_queries=self.clip_extra_context_tokens,
            embedding_dim=self.clip_embeddings_dim,
            output_dim=self.output_cross_attention_dim,
            ff_mult=4
        )
        return image_proj_model

    @torch.inference_mode()
    def get_image_embeds(self, clip_embed, clip_embed_zeroed):
        image_prompt_embeds = clip_embed.clone().detach()
        image_prompt_embeds = self.image_proj_model(image_prompt_embeds)
        #image_prompt_embeds = image_prompt_embeds.reshape([1, -1, 512])

        uncond_image_prompt_embeds = clip_embed_zeroed.clone().detach()
        uncond_image_prompt_embeds = self.image_proj_model(uncond_image_prompt_embeds)
        #uncond_image_prompt_embeds = uncond_image_prompt_embeds.reshape([1, -1, 512])

        return image_prompt_embeds, uncond_image_prompt_embeds

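# Simple linear projection variant (IP-Adapter style); kept for reference,
# nothing in this file instantiates it.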
class ImageProjModel(torch.nn.Module):
    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens

class To_KV(torch.nn.Module):
    def __init__(self, state_dict):
        super().__init__()

        self.to_kvs = torch.nn.ModuleDict()
        for key, value in state_dict.items():
            self.to_kvs[key.replace(".weight", "").replace(".", "_")] = torch.nn.Linear(value.shape[1], value.shape[0], bias=False)
            self.to_kvs[key.replace(".weight", "").replace(".", "_")].weight.data = value


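# Loads the InstantID checkpoint; flat .safetensors keys are regrouped into the
# "image_proj" / "ip_adapter" dicts that the InstantID module expects.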
class InstantIDModelLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "instantid_file": (folder_paths.get_filename_list("instantid"), )}}

    RETURN_TYPES = ("INSTANTID",)
    FUNCTION = "load_model"
    CATEGORY = "InstantID"

    def load_model(self, instantid_file):
        ckpt_path = folder_paths.get_full_path("instantid", instantid_file)

        model = comfy.utils.load_torch_file(ckpt_path, safe_load=True)

        if ckpt_path.lower().endswith(".safetensors"):
            st_model = {"image_proj": {}, "ip_adapter": {}}
            for key in model.keys():
                if key.startswith("image_proj."):
                    st_model["image_proj"][key.replace("image_proj.", "")] = model[key]
                elif key.startswith("ip_adapter."):
                    st_model["ip_adapter"][key.replace("ip_adapter.", "")] = model[key]
            model = st_model

        return (model,)

class InsightFaceLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "provider": (["CPU", "CUDA", "ROCM"], ),
            },
        }

    RETURN_TYPES = ("INSIGHTFACE",)
    FUNCTION = "load_insight_face"
    CATEGORY = "InstantID"

    def load_insight_face(self, provider):
        model = FaceAnalysis(name="antelopev2", root=INSIGHTFACE_DIR, providers=[provider + 'ExecutionProvider',]) # buffalo_l
        model.prepare(ctx_id=0, det_size=(640, 640))

        return (model,)

def tensorToNP(image):
    # ComfyUI images are float RGB in [0,1]; InsightFace expects uint8 BGR
    out = torch.clamp(255. * image.detach().cpu(), 0, 255).to(torch.uint8)
    out = out[..., [2, 1, 0]]
    out = out.numpy()

    return out

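# Main node: runs InsightFace on the input image (progressively lowering the
# detection resolution until a face is found), builds cond/uncond face embeds
# through the InstantID Resampler, and patches every cross-attention layer of
# a cloned (SDXL-layout) model with CrossAttentionPatch.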
class ApplyInstantID:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "instantid": ("INSTANTID", ),
                "insightface": ("INSIGHTFACE", ),
                "model": ("MODEL", ),
                "image": ("IMAGE", )
            },
        }

    RETURN_TYPES = ("MODEL", "IMAGE")
    RETURN_NAMES = ("MODEL", "IMAGE_KPS")
    FUNCTION = "apply_instantid"
    CATEGORY = "InstantID"

    def apply_instantid(self, instantid, insightface, model, image):
        self.dtype = torch.float16 if comfy.model_management.should_use_fp16() else torch.float32
        self.device = comfy.model_management.get_torch_device()
        self.weight = 1.0

        output_cross_attention_dim = instantid["ip_adapter"]["1.to_k_ip.weight"].shape[1]
        cross_attention_dim = 1280
        clip_extra_context_tokens = 16

        insightface.det_model.input_size = (640,640) # reset the detection size
        face_img = tensorToNP(image)
        face_embed = []
        face_kps = []

        for i in range(face_img.shape[0]):
            for size in [(size, size) for size in range(640, 128, -64)]:
                insightface.det_model.input_size = size # TODO: hacky but seems to be working
                face = insightface.get(face_img[i])
                if face:
                    face_embed.append(torch.from_numpy(face[0].embedding).unsqueeze(0))
                    face_kps.append(draw_kps(face_img[i], face[0].kps))

                    if 640 not in size:
                        print(f"\033[33mINFO: InsightFace detection resolution lowered to {size}.\033[0m")
                    break
            else:
                raise Exception('InsightFace: No face detected.')

        face_embed = torch.stack(face_embed, dim=0)
        face_kps = torch.stack(T.ToTensor()(face_kps), dim=0).permute([0,2,3,1])

        clip_embed = face_embed
        clip_embed_zeroed = torch.zeros_like(clip_embed)

        clip_embeddings_dim = face_embed.shape[-1]

        self.instantid = InstantID(
            instantid,
            cross_attention_dim=cross_attention_dim,
            output_cross_attention_dim=output_cross_attention_dim,
            clip_embeddings_dim=clip_embeddings_dim,
            clip_extra_context_tokens=clip_extra_context_tokens,
        )

        self.instantid.to(self.device, dtype=self.dtype)

        image_prompt_embeds, uncond_image_prompt_embeds = self.instantid.get_image_embeds(clip_embed.to(self.device, dtype=self.dtype), clip_embed_zeroed.to(self.device, dtype=self.dtype))

        image_prompt_embeds = image_prompt_embeds.to(self.device, dtype=self.dtype)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.to(self.device, dtype=self.dtype)

        work_model = model.clone()

        patch_kwargs = {
            "number": 0,
            "weight": self.weight,
            "instantid": self.instantid,
            "cond": image_prompt_embeds,
            "uncond": uncond_image_prompt_embeds,
        }

        for id in [4,5,7,8]: # id of input_blocks that have cross attention
            block_indices = range(2) if id in [4, 5] else range(10) # transformer_depth
            for index in block_indices:
                set_model_patch_replace(work_model, patch_kwargs, ("input", id, index))
                patch_kwargs["number"] += 1
        for id in range(6): # id of output_blocks that have cross attention
            block_indices = range(2) if id in [3, 4, 5] else range(10) # transformer_depth
            for index in block_indices:
                set_model_patch_replace(work_model, patch_kwargs, ("output", id, index))
                patch_kwargs["number"] += 1
        for index in range(10):
            set_model_patch_replace(work_model, patch_kwargs, ("middle", 0, index))
            patch_kwargs["number"] += 1

        return (work_model, face_kps, )

NODE_CLASS_MAPPINGS = {
    "InstantIDModelLoader": InstantIDModelLoader,
    "InsightFaceLoaderIID": InsightFaceLoader,
    "ApplyInstantID": ApplyInstantID,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "InstantIDModelLoader": "Load InstantID Model",
    "InsightFaceLoaderIID": "Load InsightFace IID",
    "ApplyInstantID": "Apply InstantID",
}
10 changes: 9 additions & 1 deletion README.md
@@ -1 +1,9 @@
# ComfyUI_InstantID
## NOT WORKING YET!! do not use

Initial work to support [InstantID](https://github.com/InstantID/InstantID) natively in ComfyUI.

This is mostly a placeholder; more work is needed... if I get the time.

Models go in `ComfyUI/models/instantid`; you also need the "antelopev2" models for InsightFace (see the layout sketched below).
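
A minimal sketch of the expected folder layout, inferred from the paths in `InstantID.py` (the checkpoint filename is only an example; InsightFace resolves `antelopev2` under its `models` subfolder):

    ComfyUI/models/instantid/ip-adapter.bin
    ComfyUI/models/insightface/models/antelopev2/*.onnx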

This repo is temporary and might be removed.
3 changes: 3 additions & 0 deletions __init__.py
@@ -0,0 +1,3 @@
from .InstantID import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS

__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']