⚡ Real multi-inputs unit (AUTOMATIC1111#2578)
* Add adapter tests

* real multi inputs
huchenlei authored Jan 25, 2024
1 parent 918da1b commit 6a1d882
Showing 6 changed files with 220 additions and 45 deletions.
8 changes: 7 additions & 1 deletion internal_controlnet/external_code.py
@@ -201,7 +201,13 @@ def __eq__(self, other):

def accepts_multiple_inputs(self) -> bool:
"""This unit can accept multiple input images."""
return False
return self.module in (
"ip-adapter_clip_sdxl",
"ip-adapter_clip_sdxl_plus_vith",
"ip-adapter_clip_sd15",
"ip-adapter_face_id",
"ip-adapter_face_id_plus",
)


def to_base64_nparray(encoding: str):
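A hypothetical sketch (the `select_images` helper is not part of this commit) of how a caller might branch on `accepts_multiple_inputs()`: units whose module is one of the listed IP-Adapter preprocessors keep every supplied image, while other units keep only the first.

```python
from typing import Any, List


def select_images(unit: Any, images: List[Any]) -> List[Any]:
    """Hypothetical helper: keep all images only for units that accept multiple inputs."""
    if unit.accepts_multiple_inputs():
        return list(images)   # e.g. module == "ip-adapter_clip_sd15"
    return list(images[:1])   # every other module still receives a single image
```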
75 changes: 50 additions & 25 deletions scripts/controlmodel_ipadapter.py
@@ -3,7 +3,7 @@

import torch
import torch.nn as nn
from transformers.models.clip.modeling_clip import CLIPVisionModelOutput
from einops import rearrange
from scripts.logging import logger


@@ -331,36 +331,56 @@ def load_ip_adapter(self, state_dict):
self.ip_layers = To_KV(state_dict["ip_adapter"])

@torch.inference_mode()
def get_image_embeds(self, clip_vision_output: CLIPVisionModelOutput):
self.image_proj_model.cpu()

def get_image_embeds(self, clip_vision_outputs):
self.image_proj_model.to(self.device)
clip_vision_outputs = clip_vision_outputs if isinstance(clip_vision_outputs, (list, tuple)) else [clip_vision_outputs]
if self.is_plus:
clip_embeds = torch.cat([
clip_vision_output['hidden_states'][-2]
for clip_vision_output in clip_vision_outputs
], dim=0).to(device=self.device, dtype=torch.float32)

from annotator.clipvision import clip_vision_h_uc, clip_vision_vith_uc
cond = self.image_proj_model(clip_vision_output['hidden_states'][-2].to(device='cpu', dtype=torch.float32))
uncond = clip_vision_vith_uc.to(cond) if self.sdxl_plus else self.image_proj_model(clip_vision_h_uc.to(cond))
cond = self.image_proj_model(clip_embeds)
if self.sdxl_plus:
uncond = clip_vision_vith_uc.to(cond)
else:
uncond = self.image_proj_model(clip_vision_h_uc.to(cond))
return cond, uncond

clip_image_embeds = clip_vision_output['image_embeds'].to(device='cpu', dtype=torch.float32)
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
# input zero vector for unconditional.
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
else:
clip_image_embeds = torch.cat([
clip_vision_output['image_embeds']
for clip_vision_output in clip_vision_outputs
], dim=0).to(device=self.device, dtype=torch.float32)
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds

@torch.inference_mode()
def get_image_embeds_faceid_plus(self, face_embed, clip_vision_output: CLIPVisionModelOutput, is_v2: bool):
face_embed = face_embed.to(self.device, dtype=torch.float32)
def get_image_embeds_faceid_plus(self, insightface_outputs, clip_vision_outputs, is_v2: bool):
self.image_proj_model.to(self.device)
faceid_embed_list = insightface_outputs
clip_vision_outputs = clip_vision_outputs if isinstance(clip_vision_outputs, (list, tuple)) else [clip_vision_outputs]
clip_embed_list = [
clip_vision_output['hidden_states'][-2]
for clip_vision_output in clip_vision_outputs
]
conds = []
unconds = []
from annotator.clipvision import clip_vision_h_uc
clip_embed = clip_vision_output['hidden_states'][-2].to(device=self.device, dtype=torch.float32)
return (
self.image_proj_model(face_embed, clip_embed, shortcut=is_v2),
self.image_proj_model(torch.zeros_like(face_embed), clip_vision_h_uc.to(clip_embed), shortcut=is_v2),
)
for faceid_embed, clip_embed in zip(faceid_embed_list, clip_embed_list):
faceid_embed = faceid_embed.to(device=self.device, dtype=torch.float32)
clip_embed = clip_embed.to(device=self.device, dtype=torch.float32)
conds.append(self.image_proj_model(faceid_embed, clip_embed, shortcut=is_v2))
unconds.append(self.image_proj_model(torch.zeros_like(faceid_embed), clip_vision_h_uc.to(clip_embed), shortcut=is_v2))

return torch.cat(conds, dim=0), torch.cat(unconds, dim=0)

@torch.inference_mode()
def get_image_embeds_faceid(self, insightface_outputs: List[torch.Tensor]):
"""Get image embeds for non-plus faceid. Multiple inputs are supported."""
self.image_proj_model.to(self.device)
batch_size = len(insightface_outputs)

faceid_embeds = torch.cat(insightface_outputs, dim=0).to(self.device, dtype=torch.float32)
assert faceid_embeds.ndim == 2
image_prompt_embeds = self.image_proj_model(faceid_embeds)
@@ -537,8 +557,9 @@ def hook(self, model, clip_vision_output, weight, start, end, dtype=torch.float3
else:
self.image_emb, self.uncond_image_emb = self.ipadapter.get_image_embeds(clip_vision_output)

self.image_emb = self.image_emb.to(device, dtype=self.dtype)
self.uncond_image_emb = self.uncond_image_emb.to(device, dtype=self.dtype)
assert self.image_emb.ndim == self.uncond_image_emb.ndim == 3
self.image_emb = self.image_emb.to(device, dtype=self.dtype).unsqueeze(0)
self.uncond_image_emb = self.uncond_image_emb.to(device, dtype=self.dtype).unsqueeze(0)

# From https://github.com/laksjdjf/IPAdapter-ComfyUI
if not self.sdxl:
@@ -580,6 +601,7 @@ def call_ip(self, key: str, feat, device):
def patch_forward(self, number: int):
@torch.no_grad()
def forward(attn_blk, x, q):
emb_size = self.image_emb.shape[1]
batch_size, sequence_length, inner_dim = x.shape
h = attn_blk.heads
head_dim = inner_dim // h
@@ -588,15 +610,18 @@ def forward(attn_blk, x, q):
if current_sampling_percent < self.p_start or current_sampling_percent > self.p_end:
return 0

cond_mark = current_model.cond_mark[:, :, :, 0].to(self.image_emb)
cond_mark = current_model.cond_mark.to(self.image_emb)
cond_uncond_image_emb = self.image_emb * cond_mark + self.uncond_image_emb * (1 - cond_mark)
k_key = f"{number * 2 + 1}_to_k_ip"
v_key = f"{number * 2 + 1}_to_v_ip"
ip_k = self.call_ip(k_key, cond_uncond_image_emb, device=q.device)
ip_v = self.call_ip(v_key, cond_uncond_image_emb, device=q.device)

ip_k, ip_v = map(
lambda t: t.view(batch_size, -1, h, head_dim).transpose(1, 2),
lambda t: rearrange(
t, "batch emb key (head head_dim) -> emb batch head key head_dim",
batch=batch_size, emb=emb_size, head=h, head_dim=head_dim,
),
(ip_k, ip_v),
)
assert ip_k.dtype == ip_v.dtype
@@ -608,7 +633,7 @@ def forward(attn_blk, x, q):
ip_v = ip_v.to(dtype=q.dtype)

ip_out = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False)
ip_out = ip_out.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
ip_out = ip_out.mean(dim=0).transpose(1, 2).reshape(batch_size, -1, h * head_dim)

return ip_out * self.weight
return forward
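Below is a self-contained sketch of the reshaping the patched `forward` applies to the IP-Adapter keys and values. The tensor sizes (2 prompts for cond/uncond, 3 input images as the new leading "emb" dimension, 4 image tokens, 8 heads of size 40, 77 query tokens) are assumptions for illustration only, and the explicit `expand` stands in for the implicit broadcasting the real code relies on: attention is evaluated once per input image and the results are averaged over that leading dimension.

```python
import torch
import torch.nn.functional as F
from einops import rearrange

# Assumed shapes for illustration only.
batch, emb, key, heads, head_dim, seq = 2, 3, 4, 8, 40, 77
ip_k = torch.randn(batch, emb, key, heads * head_dim)
ip_v = torch.randn(batch, emb, key, heads * head_dim)
q = torch.randn(batch, heads, seq, head_dim)

# Split heads and move the per-image dimension to the front, as the patched
# forward() does, so attention runs once per input image.
ip_k, ip_v = (
    rearrange(t, "batch emb key (head head_dim) -> emb batch head key head_dim",
              head=heads, head_dim=head_dim)
    for t in (ip_k, ip_v)
)

# The patched forward() lets q broadcast against the leading emb dimension;
# the expansion is written out here to keep the sketch unambiguous.
q = q.unsqueeze(0).expand(emb, -1, -1, -1, -1)
ip_out = F.scaled_dot_product_attention(q, ip_k, ip_v)

# Average over the per-image dimension, then merge heads back, mirroring
# ip_out.mean(dim=0).transpose(1, 2).reshape(batch_size, -1, h * head_dim).
ip_out = ip_out.mean(dim=0).transpose(1, 2).reshape(batch, seq, heads * head_dim)
assert ip_out.shape == (batch, seq, heads * head_dim)
```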
1 change: 0 additions & 1 deletion scripts/controlnet_ui/controlnet_ui_group.py
@@ -182,7 +182,6 @@ def unfold_merged(self) -> List[external_code.ControlNetUnit]:
unit = copy(self)
unit.image = image["image"]
unit.input_mode = InputMode.SIMPLE
unit.weight = 1 / len(self.image)
result.append(unit)
return result

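As a toy recreation of the change above: `unfold_merged` still produces one simple-mode unit per image, but no longer divides the unit weight by the image count, since weighting across inputs is now handled by the real multi-input path. `ToyUnit` is a reduced stand-in for the real unit class, used only for illustration.

```python
from copy import copy
from dataclasses import dataclass
from typing import Any, List


@dataclass
class ToyUnit:
    """Stand-in for the real ControlNetUnit; fields reduced for illustration."""
    image: Any = None
    input_mode: str = "merged"
    weight: float = 1.0


def unfold_merged(unit: ToyUnit) -> List[ToyUnit]:
    """One simple-mode unit per image; weight is no longer split across units."""
    result = []
    for image in unit.image:
        u = copy(unit)
        u.image = image["image"]
        u.input_mode = "simple"
        result.append(u)
    return result


units = unfold_merged(ToyUnit(image=[{"image": "img-a"}, {"image": "img-b"}]))
assert [u.weight for u in units] == [1.0, 1.0]  # before this commit each would be 0.5
```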
36 changes: 21 additions & 15 deletions scripts/processor.py
@@ -379,16 +379,23 @@ def unload_pidinet():
}


def clip(img, res=512, config='clip_vitl', low_vram=False, **kwargs):
img = HWC3(img)
global clip_encoder
if clip_encoder[config] is None:
from annotator.clipvision import ClipVisionDetector
if low_vram:
logger.info("Loading CLIP model on CPU.")
clip_encoder[config] = ClipVisionDetector(config, low_vram)
result = clip_encoder[config](img)
return result, False
def clip(imgs: Union[Tuple[np.ndarray], np.ndarray], config='clip_vitl', low_vram=False, **kwargs):
imgs = imgs if isinstance(imgs, tuple) else (imgs,)
result = []
for img in imgs:
img = HWC3(img)
global clip_encoder
if clip_encoder[config] is None:
from annotator.clipvision import ClipVisionDetector
if low_vram:
logger.info("Loading CLIP model on CPU.")
clip_encoder[config] = ClipVisionDetector(config, low_vram)
result.append(clip_encoder[config](img))

if len(result) == 1:
return result[0], False
else:
return result, False


def unload_clip(config='clip_vitl'):
@@ -740,12 +747,11 @@ def run_model(self, imgs: Union[Tuple[np.ndarray], np.ndarray], **kwargs):
g_insight_face_model = InsightFaceModel()


def face_id_plus(img, low_vram=False, **kwargs):
def face_id_plus(imgs: Union[Tuple[np.ndarray], np.ndarray], low_vram=False, **kwargs):
""" FaceID plus uses both face_embeding from insightface and clip_embeding from clip. """
face_embed, _ = g_insight_face_model.run_model(img)
clip_embed, _ = clip(img, config='clip_h', low_vram=low_vram)
assert len(face_embed) > 0
return (face_embed[0], clip_embed), False
face_embed, _ = g_insight_face_model.run_model(imgs)
clip_embed, _ = clip(imgs, config='clip_h', low_vram=low_vram)
return (face_embed, clip_embed), False


class HandRefinerModel:
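The updated `clip` preprocessor accepts either a single image or a tuple of images, loops over them, and returns a bare result for the single-image case so existing callers are unchanged. A minimal sketch of the tuple-or-single normalization it performs; `normalize_inputs` is a hypothetical helper, not part of this commit.

```python
import numpy as np

# Any HWC uint8 array stands in for a real input image.
single = np.zeros((64, 64, 3), dtype=np.uint8)


def normalize_inputs(imgs):
    """Hypothetical helper mirroring the tuple-or-single check the updated
    clip() preprocessor performs before looping over its inputs."""
    return imgs if isinstance(imgs, tuple) else (imgs,)


assert len(normalize_inputs(single)) == 1            # single image is wrapped
assert len(normalize_inputs((single, single))) == 2  # multi-input tuple passes through
```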
141 changes: 139 additions & 2 deletions tests/web_api/full_coverage/ipadapter_test.py
@@ -40,6 +40,143 @@ def lora_prompt(self) -> str:
"no_neg": "",
}


sd15_full_face = AdapterSetting(
"ip-adapter_clip_sd15",
"ip-adapter-full-face_sd15 [852b9843]",
)
sd15_plus_face = AdapterSetting(
"ip-adapter_clip_sd15",
"ip-adapter-plus-face_sd15 [71693645]",
)
sd15_normal = AdapterSetting(
"ip-adapter_clip_sd15",
"ip-adapter_sd15 [6a3f6166]",
)
sd15_light = AdapterSetting(
"ip-adapter_clip_sd15",
"ip-adapter_sd15_light [be1c9b97]",
)
sdxl_normal = AdapterSetting(
"ip-adapter_clip_sdxl",
"ip-adapter_sdxl [d5d53548]"
)
sdxl_vit = AdapterSetting(
"ip-adapter_clip_sdxl_plus_vith",
"ip-adapter_sdxl_vit-h [75a08f84]",
)
sdxl_plus_vit = AdapterSetting(
"ip-adapter_clip_sdxl_plus_vith",
"ip-adapter-plus_sdxl_vit-h [f1f19f7d]",
)
sdxl_plus_vit_face = AdapterSetting(
"ip-adapter_clip_sdxl_plus_vith",
"ip-adapter-plus-face_sdxl_vit-h [c60d7d48]",
)
class TestIPAdapterFullCoverage(unittest.TestCase):
def setUp(self):
if not is_full_coverage:
pytest.skip()

if sd_version == StableDiffusionVersion.SDXL:
self.settings = [
sdxl_normal,
sdxl_vit,
sdxl_plus_vit,
sdxl_plus_vit_face,
]
else:
self.settings = [
sd15_normal,
sd15_light,
sd15_plus_face,
sd15_full_face,
]

def test_adapter(self):
for s in self.settings:
for n, negative_prompt in negative_prompts.items():
name = f"{s}_{n}"
with self.subTest(name=name):
self.assertTrue(
APITestTemplate(
name,
"txt2img",
payload_overrides={
"prompt": f"{base_prompt},{s.lora_prompt}",
"negative_prompt": negative_prompt,
"steps": 20,
"width": 512,
"height": 512,
},
unit_overrides=[
{
"module": s.module,
"model": s.model,
"image": realistic_girl_face_img,
},
openpose_unit,
],
).exec()
)

def test_adapter_multi_inputs(self):
for s in self.settings:
for n, negative_prompt in negative_prompts.items():
name = f"multi_inputs_{s}_{n}"
with self.subTest(name=name):
self.assertTrue(
APITestTemplate(
name=name,
gen_type="txt2img",
payload_overrides={
"prompt": f"{base_prompt}, {s.lora_prompt}",
"negative_prompt": negative_prompt,
"steps": 20,
"width": 512,
"height": 512,
},
unit_overrides=[openpose_unit]
+ [
{
"image": img,
"module": s.module,
"model": s.model,
"weight": 1 / len(portrait_imgs),
}
for img in portrait_imgs
],
).exec()
)

def test_adapter_real_multi_inputs(self):
for s in self.settings:
for n, negative_prompt in negative_prompts.items():
name = f"real_multi_{s}_{n}"
with self.subTest(name=name):
self.assertTrue(
APITestTemplate(
name=name,
gen_type="txt2img",
payload_overrides={
"prompt": f"{base_prompt}, {s.lora_prompt}",
"negative_prompt": negative_prompt,
"steps": 20,
"width": 512,
"height": 512,
},
unit_overrides=[
openpose_unit,
{
"image": [{"image": img} for img in portrait_imgs],
"module": s.module,
"model": s.model,
},
],
).exec()
)
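
For reference, the two unit payload shapes these tests exercise: the "real multi-input" form added here (a single unit whose `image` field is a list of image dicts) and the older unfolded form (one unit per image with the weight split by hand). The base64 strings are placeholders; the module and model names follow the SD1.5 settings used above.

```python
# New "real multi-input" form: one unit, several images.
real_multi_input_unit = {
    "module": "ip-adapter_clip_sd15",
    "model": "ip-adapter_sd15 [6a3f6166]",
    "image": [
        {"image": "<base64 image 1>"},
        {"image": "<base64 image 2>"},
    ],
}

# Older form: the same images unfolded into one unit each, weights split manually.
unfolded_units = [
    {
        "module": "ip-adapter_clip_sd15",
        "model": "ip-adapter_sd15 [6a3f6166]",
        "image": "<base64 image 1>",
        "weight": 0.5,
    },
    {
        "module": "ip-adapter_clip_sd15",
        "model": "ip-adapter_sd15 [6a3f6166]",
        "image": "<base64 image 2>",
        "weight": 0.5,
    },
]
```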


sd15_face_id = AdapterSetting(
"ip-adapter_face_id",
"ip-adapter-faceid_sd15 [0a1757e9]",
@@ -66,7 +203,7 @@ def lora_prompt(self) -> str:
)


class TestIPAdapterFullCoverage(unittest.TestCase):
class TestIPAdapterFaceIdFullCoverage(unittest.TestCase):
def setUp(self):
if not is_full_coverage:
pytest.skip()
@@ -138,7 +275,7 @@ def test_face_id_multi_inputs(self):
)

def test_face_id_real_multi_inputs(self):
for s in (sd15_face_id, sd15_face_id_portrait):
for s in self.settings:
for n, negative_prompt in negative_prompts.items():
name = f"real_multi_{s}_{n}"
with self.subTest(name=name):
4 changes: 3 additions & 1 deletion tests/web_api/full_coverage/template.py
@@ -163,7 +163,9 @@ def expect_same_image(img1, img2, diff_img_path: str) -> bool:
# Save the diff_highlighted image to inspect the differences
cv2.imwrite(diff_img_path, diff_highlighted)

return similar
matching_pixels = np.isclose(img1, img2, rtol=0.5, atol=1)
similar_in_general = (matching_pixels.sum() / matching_pixels.size) >= 0.95
return similar_in_general


default_unit = {
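The relaxed comparison above counts the fraction of per-channel values within an `np.isclose` tolerance instead of requiring an exact match. A toy example with assumed 8×8 images, showing how one badly mismatched pixel still passes the 95% threshold.

```python
import numpy as np

# Two small "images" that differ in a single pixel.
img1 = np.full((8, 8, 3), 100, dtype=np.uint8)
img2 = img1.copy()
img2[0, 0] = 255  # one blown-out pixel: 3 of 192 channel values exceed the tolerance

matching_pixels = np.isclose(img1, img2, rtol=0.5, atol=1)
similar_in_general = (matching_pixels.sum() / matching_pixels.size) >= 0.95
assert similar_in_general  # 189/192 ≈ 98% of values still match, above the 95% bar
```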
