Commit c2a12e1
author: somanchiu
committed: init
1 parent b699ec8 commit c2a12e1

22 files changed: +1549 -0 lines changed

INSwapper.py

Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
import numpy as np
import onnxruntime
import cv2
import onnx
from onnx import numpy_helper
from insightface.utils import face_align

class INSwapper():
    def __init__(self, model_file=None, session=None):
        self.model_file = model_file
        self.session = session
        model = onnx.load(self.model_file)
        graph = model.graph
        # "emap" is stored as the last initializer of the original inswapper graph
        self.emap = numpy_helper.to_array(graph.initializer[-1])

        # emapFile = f'training/dataset/v6/emap.npy'
        # np.save(emapFile, self.emap)
        # emap = np.load(emapFile)

        self.input_mean = 0.0
        self.input_std = 255.0
        #print('input mean and std:', model_file, self.input_mean, self.input_std)
        if self.session is None:
            self.session = onnxruntime.InferenceSession(self.model_file, None)
        inputs = self.session.get_inputs()
        self.input_names = []
        for inp in inputs:
            self.input_names.append(inp.name)
        outputs = self.session.get_outputs()
        output_names = []
        for out in outputs:
            output_names.append(out.name)
        self.output_names = output_names
        assert len(self.output_names) == 1
        output_shape = outputs[0].shape
        input_cfg = inputs[0]
        input_shape = input_cfg.shape
        self.input_shape = input_shape
        print('inswapper-shape:', self.input_shape)
        self.input_size = tuple(input_shape[2:4][::-1])

    def forward(self, img, latent):
        img = (img - self.input_mean) / self.input_std
        pred = self.session.run(self.output_names, {self.input_names[0]: img, self.input_names[1]: latent})[0]
        return pred

    def predict(self, blob, latent):
        input_feed = {self.input_names[0]: blob, self.input_names[1]: latent}
        pred = self.session.run(self.output_names, input_feed)[0]
        return pred

    def test(self, img, target_face, source_face):
        aimg, M = face_align.norm_crop2(img, target_face.kps, self.input_size[0])
        blob = cv2.dnn.blobFromImage(aimg, 1.0 / self.input_std, self.input_size,
                                     (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
        latent = source_face.normed_embedding.reshape((1, -1))
        latent = np.dot(latent, self.emap)
        latent /= np.linalg.norm(latent)

        pred = self.predict(blob, latent)
        return pred

    def getBlob(self, aimg):
        # NCHW float blob, scaled to [0, 1] and converted BGR -> RGB
        blob = cv2.dnn.blobFromImage(aimg, 1.0 / self.input_std, self.input_size,
                                     (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
        return blob

    def getLatent(self, source_face):
        # project the ArcFace embedding through emap and re-normalize
        latent = source_face.normed_embedding.reshape((1, -1))
        latent = np.dot(latent, self.emap)
        latent /= np.linalg.norm(latent)

        return latent

    def swap(self, alignedTargetFace, source_face):
        latent = self.getLatent(source_face)

        pred = self.predict(self.getBlob(alignedTargetFace), latent)
        #print(latent.shape, latent.dtype, pred.shape)
        # model output is NCHW RGB in [0, 1]; convert back to uint8 BGR
        img_fake = pred.transpose((0, 2, 3, 1))[0]
        bgr_fake = np.clip(255 * img_fake, 0, 255).astype(np.uint8)[:, :, ::-1]

        return bgr_fake

    def swapAndPasteBack(self, img, alignedTargetFace, alignedTargetFaceM, bgr_fake):
        target_img = img
        fake_diff = bgr_fake.astype(np.float32) - alignedTargetFace.astype(np.float32)
        fake_diff = np.abs(fake_diff).mean(axis=2)
        # zero out a 2-pixel border so edge artifacts don't leak into the mask
        fake_diff[:2, :] = 0
        fake_diff[-2:, :] = 0
        fake_diff[:, :2] = 0
        fake_diff[:, -2:] = 0
        IM = cv2.invertAffineTransform(alignedTargetFaceM)
        img_white = np.full((alignedTargetFace.shape[0], alignedTargetFace.shape[1]), 255, dtype=np.float32)
        bgr_fake = cv2.warpAffine(bgr_fake, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
        img_white = cv2.warpAffine(img_white, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
        fake_diff = cv2.warpAffine(fake_diff, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
        img_white[img_white > 20] = 255
        fthresh = 10
        fake_diff[fake_diff < fthresh] = 0
        fake_diff[fake_diff >= fthresh] = 255
        img_mask = img_white
        mask_h_inds, mask_w_inds = np.where(img_mask == 255)
        mask_h = np.max(mask_h_inds) - np.min(mask_h_inds)
        mask_w = np.max(mask_w_inds) - np.min(mask_w_inds)
        mask_size = int(np.sqrt(mask_h * mask_w))
        # erode the mask, then feather it with a Gaussian blur for a soft blend
        k = max(mask_size // 10, 10)
        #k = max(mask_size//20, 6)
        #k = 6
        kernel = np.ones((k, k), np.uint8)
        img_mask = cv2.erode(img_mask, kernel, iterations=1)
        kernel = np.ones((2, 2), np.uint8)
        fake_diff = cv2.dilate(fake_diff, kernel, iterations=1)
        k = max(mask_size // 20, 5)
        #k = 3
        kernel_size = (k, k)
        blur_size = tuple(2 * i + 1 for i in kernel_size)
        img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)
        k = 5
        kernel_size = (k, k)
        blur_size = tuple(2 * i + 1 for i in kernel_size)
        fake_diff = cv2.GaussianBlur(fake_diff, blur_size, 0)
        img_mask /= 255
        fake_diff /= 255
        #img_mask = fake_diff
        img_mask = np.reshape(img_mask, [img_mask.shape[0], img_mask.shape[1], 1])
        fake_merged = img_mask * bgr_fake + (1 - img_mask) * target_img.astype(np.float32)
        fake_merged = fake_merged.astype(np.uint8)
        return fake_merged

    def get(self, img, target_face, source_face, paste_back=True):
        aimg, M = face_align.norm_crop2(img, target_face.kps, self.input_size[0])

        bgr_fake = self.swap(aimg, source_face)
        if not paste_back:
            return bgr_fake, M
        else:
            return self.swapAndPasteBack(img, aimg, M, bgr_fake)
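A minimal usage sketch of this class (face detection via insightface's FaceAnalysis; the model path and image filenames here are assumptions, not part of the commit):

```python
import cv2
from insightface.app import FaceAnalysis
from INSwapper import INSwapper

app = FaceAnalysis()
app.prepare(ctx_id=0, det_size=(640, 640))
swapper = INSwapper('inswapper_128.onnx')  # assumed model path

target_img = cv2.imread('target.jpg')
source_img = cv2.imread('source.jpg')
target_face = app.get(target_img)[0]  # detected faces expose .kps and .normed_embedding
source_face = app.get(source_img)[0]

result = swapper.get(target_img, target_face, source_face, paste_back=True)
cv2.imwrite('output.jpg', result)
```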

Image.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
import cv2
import numpy as np

# emap.npy must be extracted from the original inswapper model first
emap = np.load("emap.npy")
input_std = 255.0
input_mean = 0.0
input_size = (128, 128)

def postprocess_face(face_tensor):
    # convert a [1, 3, H, W] torch tensor in [0, 1] back to a uint8 BGR image
    face_tensor = face_tensor.squeeze().cpu().detach()
    face_np = (face_tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    face_np = cv2.cvtColor(face_np, cv2.COLOR_RGB2BGR)

    return face_np

def getBlob(aimg):
    # NCHW float blob, scaled to [0, 1] and converted BGR -> RGB
    blob = cv2.dnn.blobFromImage(aimg, 1.0 / input_std, input_size,
                                 (input_mean, input_mean, input_mean), swapRB=True)
    return blob

def getLatent(source_face):
    # project the ArcFace embedding through emap and re-normalize
    latent = source_face.normed_embedding.reshape((1, -1))
    latent = np.dot(latent, emap)
    latent /= np.linalg.norm(latent)

    return latent
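A minimal sketch of how these helpers feed the swapper model, mirroring INSwapper.predict (the model path, the ONNX input names 'target'/'source', and the detected source_face are assumptions):

```python
import cv2
import onnxruntime

session = onnxruntime.InferenceSession("inswapper_128.onnx", None)  # assumed model path
aligned_face = cv2.imread("aligned_target_face.jpg")  # a 128x128 aligned face crop (assumed file)
blob = getBlob(aligned_face)     # (1, 3, 128, 128) float32
latent = getLatent(source_face)  # (1, 512) float32; source_face from insightface detection
pred = session.run(None, {"target": blob, "source": latent})[0]
```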

README.md

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
# ReSwapper

ReSwapper aims to reproduce the implementation of inswapper. This repository provides code for training, inference, and includes pretrained weights.

Here is a comparison of the outputs of Inswapper and ReSwapper.

| Target | Source | Inswapper Output | ReSwapper Output (Step 429500) |
|--------|--------|--------|--------|
| ![targetImg](example/1/target.jpg) | ![sourceImg](example/1/source.jpg) | ![inswapperOutput](example/1/inswapperOutput.jpg) | ![reswapperOutput](example/1/reswapperOutput.jpg) |
| ![targetImg](example/2/target.jpg) | ![sourceImg](example/2/source.jpg) | ![inswapperOutput](example/2/inswapperOutput.jpg) | ![reswapperOutput](example/2/reswapperOutput.jpg) |
| ![targetImg](example/3/target.jpg) | ![sourceImg](example/3/source.png) | ![inswapperOutput](example/3/inswapperOutput.jpg) | ![reswapperOutput](example/3/reswapperOutput.jpg) |

## Installation

```bash
git clone https://github.com/somanchiu/ReSwapper.git
cd ReSwapper
python -m venv venv

venv\scripts\activate

pip install -r requirements.txt

pip install torch torchvision --force --index-url https://download.pytorch.org/whl/cu121
pip install onnxruntime-gpu --force --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
```

## The details of inswapper

### Model architecture
The inswapper model architecture can be visualized in [Netron](https://netron.app). You can compare it with the ReSwapper implementation to see the architectural similarities.

We can also use the following Python code to get more details:
```python
import onnx

model = onnx.load('test.onnx')
printable_graph = onnx.helper.printable_graph(model.graph)
print(printable_graph)
```

### Model input
- target: [1, 3, 128, 128] shape, normalized to the [0, 1] range (the code uses input_mean = 0.0 and input_std = 255.0)
- source (latent): [1, 512] shape, the features of the source face
- Calculation of the latent: "emap" can be extracted from the original inswapper model (see the sketch after this code block).
```python
latent = source_face.normed_embedding.reshape((1,-1))
latent = np.dot(latent, emap)
latent /= np.linalg.norm(latent)
```
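
A minimal sketch of the emap extraction (INSwapper.py in this commit reads it the same way; the model filename is an assumption):
```python
import numpy as np
import onnx
from onnx import numpy_helper

model = onnx.load('inswapper_128.onnx')  # assumed filename of the original inswapper model
# emap is stored as the last initializer of the graph
emap = numpy_helper.to_array(model.graph.initializer[-1])
np.save('emap.npy', emap)  # Image.py loads it from this file
```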
47+
48+
49+
### Loss Functions
50+
There is no information released from insightface. It is an important part of the training. However, there are a lot of articles and papers that can be referenced. By reading a substantial number of articles and papers on face swapping, ID fidelity, and style transfer, you'll frequently encounter the following keywords:
51+
- content loss
52+
- style loss/id loss
53+
- perceptual loss
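
One concrete combination, and the one StyleTransferLoss.py in this commit implements, is an SSIM-based content loss plus a cosine-similarity identity loss over face latents:
```python
import torch
from pytorch_msssim import ssim

def swap_losses(output_image, target_image, output_latent, target_latent):
    # content loss: 1 - SSIM between output and target, with images in [0, 1]
    content_loss = 1 - ssim(output_image, target_image, data_range=1.0)
    # identity loss: cosine similarity remapped from [-1, 1] to [0, 1],
    # inverted, squared, and scaled (same shaping as StyleTransferLoss.forward)
    similarity = torch.nn.functional.cosine_similarity(output_latent, target_latent, dim=0)
    identity_loss = (1 - (similarity + 1) / 2) ** 2 * 10
    return content_loss, identity_loss
```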

## Training
### 0. Pretrained weights (Optional)
If you don't want to train the model from scratch, you can download the pretrained weights and pass model_path into the train function in train.py.

### 1. Dataset Preparation
Download [FFHQ](https://www.kaggle.com/datasets/arnaud58/flickrfaceshq-dataset-ffhq) to use as target and source images. For the swapped face images, we can use the inswapper output.

### 2. Model Training

Optimizer: Adam

Learning rate: 0.0001
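
A minimal sketch of this configuration (assuming the generator in train.py is a torch.nn.Module named model):
```python
import torch

# stated training configuration; the model variable is an assumption
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
```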

Modify the code in train.py if needed. Then, execute:
```bash
python train.py
```

The model will be saved as "reswapper-\<total steps\>.pth".

## Notes
- Do not stop the training too early.

- I'm using an RTX3060 12GB for training. It takes around 12 hours for 50,000 steps.
- The optimizer may need to be changed to SGD for the final training, as many articles show that SGD can result in lower loss.

## Inference
```bash
python swap.py
```

## Pretrained Model

- [reswapper-429500.pth](https://huggingface.co/somanchiu/reswapper/tree/main)

## To Do
- Create 512 resolution model
- Implement face paste-back functionality
- Add emap to the onnx file

StyleTransferLoss.py

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
import cv2
import torch
import torch.nn as nn
import numpy as np
from insightface.app import FaceAnalysis
from pytorch_msssim import ssim

class StyleTransferLoss(nn.Module):
    def __init__(self, device='cuda', inswapper=None):
        super(StyleTransferLoss, self).__init__()
        self.face_analysis = FaceAnalysis(providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        self.face_analysis.prepare(ctx_id=0, det_size=(128, 128))
        self.device = device
        self.inswapper = inswapper
        self.cosine_similarity = nn.CosineSimilarity(dim=0)

        # Content loss
        self.content_loss = nn.MSELoss()

        # Style loss
        self.style_loss = nn.MSELoss()

        # Face identity loss
        self.identity_loss = nn.CosineSimilarity(dim=1, eps=1e-6)

    def gram_matrix(self, input):
        # a, b, c, d = input.size()
        # features = input.view(a * b, c * d)
        G = torch.mm(input, input.t())
        return G

    def extract_face_embedding(self, image):
        # Convert torch tensor to numpy array (uint8 BGR)
        face_tensor = image.squeeze().cpu().detach()
        # face_tensor = (face_tensor * 0.5 + 0.5).clamp(0, 1)
        face_np = (face_tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        face_np = cv2.cvtColor(face_np, cv2.COLOR_RGB2BGR)

        # Extract face embedding; None if no face is detected
        faces = self.face_analysis.get(face_np)
        if len(faces) == 0:
            return None
        return torch.tensor(faces[0].normed_embedding).to(self.device)

    def extract_face_latent(self, image):
        # Convert torch tensor to numpy array (uint8 BGR)
        face_tensor = image.squeeze().cpu().detach()
        # face_tensor = (face_tensor * 0.5 + 0.5).clamp(0, 1)
        face_np = (face_tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        face_np = cv2.cvtColor(face_np, cv2.COLOR_RGB2BGR)

        # Extract the inswapper latent of the detected face; None if no face is detected
        faces = self.face_analysis.get(face_np)
        if len(faces) == 0:
            return None
        return torch.tensor(self.inswapper.getLatent(faces[0])[0]).to(self.device)

    def get_style_loss(self, latent1, latent2):
        # target = torch.tensor([1.0]).to("cuda")
        similarity = torch.dot(latent1, latent2)

        # # Binary Cross-Entropy Loss
        # epsilon = 1e-7  # Small value to avoid log(0)
        # loss = -target * torch.log(similarity + epsilon) - (1 - target) * torch.log(1 - similarity + epsilon)

        return 1 - similarity

    def forward(self, output_image, target_content, target_face_latent, source_face_latent):
        # Content loss: 1 - SSIM instead of plain MSE
        # content_loss = self.content_loss(output_image, target_content)
        content_loss = 1 - ssim(output_image, target_content, data_range=1.0)

        # Style loss (unused; kept for reference)
        # style_loss = 0
        # for out_feature, style_feature in zip(output_features, style_features):
        #     out_gram = self.gram_matrix(out_feature)
        #     style_gram = self.gram_matrix(style_feature)
        #     style_loss += self.style_loss(out_gram, style_gram)

        # Face identity loss
        output_embedding = self.extract_face_latent(output_image)
        target_embedding = self.extract_face_latent(target_content)

        identity_loss = None
        euclidean_distance = None

        if output_embedding is not None and target_embedding is not None:
            similarity = self.cosine_similarity(output_embedding, target_embedding)
            # similarity2 = self.cosine_similarity(output_embedding, torch.tensor(target_face_latent).to(self.device))
            # similarity2 = (similarity2 + 1) / 2
            # remap cosine similarity from [-1, 1] to [0, 1], invert, square, and scale
            identity_loss = 1 - ((similarity + 1) / 2)
            identity_loss = identity_loss ** 2 * 10
            # euclidean_distance = torch.sqrt(torch.sum((output_embedding - target_embedding) ** 2))
            # similarityA = self.cosine_similarity(output_embedding, output_embedding)
            # similarityB = self.cosine_similarity(target_embedding, target_embedding)

            # identity_loss += similarity2
            # margin = 0.2
            # identity_loss = nn.functional.relu(margin - similarity)

            # target = torch.tensor([1.0]).to("cuda")
            # # Binary Cross-Entropy Loss
            # loss = -target * torch.log(similarity) - (1 - target) * torch.log(1 - similarity)

            # identity_loss = loss.mean()
            # identity_loss = 1 - self.identity_loss(output_embedding.unsqueeze(0), target_embedding.unsqueeze(0)).mean()
            # identity_loss = self.get_style_loss(output_embedding, target_embedding)
            # identity_loss = self.content_loss(output_embedding, target_embedding)
            # identity_loss = 1 - torch.nn.functional.cosine_similarity(output_embedding, target_embedding, dim=0)
            # identity_loss = torch.tensor(0.0).to(self.device)

        # Total loss (you can adjust the weights as needed)
        # total_loss = content_loss*0.1 + identity_loss

        return content_loss, identity_loss, euclidean_distance

# Usage example:
# loss_fn = StyleTransferLoss(inswapper=inswapper)
# content_loss, identity_loss, euclidean_distance = loss_fn(output_image, target_content, target_face_latent, source_face_latent)