|
| 1 | +import numpy as np |
| 2 | +import onnxruntime |
| 3 | +import cv2 |
| 4 | +import onnx |
| 5 | +from onnx import numpy_helper |
| 6 | +from insightface.utils import face_align |
| 7 | + |
class INSwapper():
    """Face swapper built on an inswapper-style ONNX model.

    The model takes two inputs -- an aligned target-face image blob and an
    identity latent vector -- and emits a single swapped-face tensor in
    NCHW float format with values in [0, 1].
    """

    def __init__(self, model_file=None, session=None):
        """Load the ONNX model and cache its I/O metadata.

        Args:
            model_file: path to the inswapper ``.onnx`` file.
            session: optional pre-built ``onnxruntime.InferenceSession``;
                one is created from ``model_file`` when omitted.
        """
        self.model_file = model_file
        self.session = session
        # The last graph initializer is the linear map ("emap") that
        # projects an identity embedding into the model's latent space.
        model = onnx.load(self.model_file)
        graph = model.graph
        self.emap = numpy_helper.to_array(graph.initializer[-1])
        # Input normalization is (img - mean) / std, i.e. scale to [0, 1].
        self.input_mean = 0.0
        self.input_std = 255.0
        if self.session is None:
            self.session = onnxruntime.InferenceSession(self.model_file, None)
        inputs = self.session.get_inputs()
        self.input_names = [inp.name for inp in inputs]
        outputs = self.session.get_outputs()
        self.output_names = [out.name for out in outputs]
        # The swapper model is expected to have exactly one output tensor.
        assert len(self.output_names) == 1
        input_shape = inputs[0].shape
        self.input_shape = input_shape
        print('inswapper-shape:', self.input_shape)
        # NCHW shape -> (width, height) tuple as expected by cv2.
        self.input_size = tuple(input_shape[2:4][::-1])

    def forward(self, img, latent):
        """Normalize an image tensor and run the model on it.

        Args:
            img: un-normalized image tensor (model layout).
            latent: projected, normalized identity latent.

        Returns:
            The raw model output tensor.
        """
        img = (img - self.input_mean) / self.input_std
        # Delegate the session call to predict() so the feed-dict is
        # built in exactly one place.
        return self.predict(img, latent)

    def predict(self, blob, latent):
        """Run the ONNX session on a preprocessed blob and latent.

        Returns:
            The model's single output tensor (NCHW float).
        """
        # Renamed from `input` to avoid shadowing the builtin.
        feed = {self.input_names[0]: blob, self.input_names[1]: latent}
        return self.session.run(self.output_names, feed)[0]

    def test(self, img, target_face, source_face):
        """Align ``target_face`` out of ``img``, swap in ``source_face``.

        Returns:
            The raw model prediction (NCHW float tensor).
        """
        aimg, M = face_align.norm_crop2(img, target_face.kps, self.input_size[0])
        # Reuse the shared preprocessing helpers instead of duplicating
        # the blob/latent construction inline.
        return self.predict(self.getBlob(aimg), self.getLatent(source_face))

    def getBlob(self, aimg):
        """Convert an aligned BGR face crop into the model's RGB NCHW blob."""
        blob = cv2.dnn.blobFromImage(aimg, 1.0 / self.input_std, self.input_size,
                                     (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
        return blob

    def getLatent(self, source_face):
        """Project and L2-normalize the source identity embedding.

        Args:
            source_face: object exposing ``normed_embedding`` (1-D array).

        Returns:
            A (1, D) unit-norm latent vector after the emap projection.
        """
        latent = source_face.normed_embedding.reshape((1, -1))
        latent = np.dot(latent, self.emap)
        latent /= np.linalg.norm(latent)
        return latent

    def swap(self, alignedTargetFace, source_face):
        """Swap the source identity onto an aligned target crop.

        Args:
            alignedTargetFace: aligned BGR face crop of ``input_size``.
            source_face: identity donor (needs ``normed_embedding``).

        Returns:
            The swapped crop as a uint8 BGR image.
        """
        latent = self.getLatent(source_face)
        pred = self.predict(self.getBlob(alignedTargetFace), latent)
        img_fake = pred.transpose((0, 2, 3, 1))[0]
        # Model output is RGB in [0, 1]; convert to uint8 BGR.
        bgr_fake = np.clip(255 * img_fake, 0, 255).astype(np.uint8)[:, :, ::-1]
        return bgr_fake

    def swapAndPasteBack(self, img, alignedTargetFace, alignedTargetFaceM, bgr_fake):
        """Blend the swapped crop back into the full frame.

        Args:
            img: full target frame (BGR uint8).
            alignedTargetFace: the aligned crop that was swapped.
            alignedTargetFaceM: the 2x3 affine that produced the crop.
            bgr_fake: swapped crop from swap(), same size as the aligned crop.

        Returns:
            The full frame with the swapped face composited in (uint8).
        """
        target_img = img
        # Per-pixel change map between swapped and original crop; its
        # border rows/cols are zeroed so crop-edge artifacts do not leak
        # into the mask after warping.
        fake_diff = bgr_fake.astype(np.float32) - alignedTargetFace.astype(np.float32)
        fake_diff = np.abs(fake_diff).mean(axis=2)
        fake_diff[:2, :] = 0
        fake_diff[-2:, :] = 0
        fake_diff[:, :2] = 0
        fake_diff[:, -2:] = 0
        # Warp the fake crop, a solid-white coverage mask, and the diff
        # map back into full-frame coordinates via the inverse affine.
        IM = cv2.invertAffineTransform(alignedTargetFaceM)
        img_white = np.full((alignedTargetFace.shape[0], alignedTargetFace.shape[1]), 255, dtype=np.float32)
        bgr_fake = cv2.warpAffine(bgr_fake, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
        img_white = cv2.warpAffine(img_white, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
        fake_diff = cv2.warpAffine(fake_diff, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0)
        # Re-binarize after warping interpolation softened the edges.
        img_white[img_white > 20] = 255
        fthresh = 10
        fake_diff[fake_diff < fthresh] = 0
        fake_diff[fake_diff >= fthresh] = 255
        img_mask = img_white
        # Shrink the mask proportionally to the face size, then feather
        # it with a Gaussian blur to get a soft blend seam.
        mask_h_inds, mask_w_inds = np.where(img_mask == 255)
        mask_h = np.max(mask_h_inds) - np.min(mask_h_inds)
        mask_w = np.max(mask_w_inds) - np.min(mask_w_inds)
        mask_size = int(np.sqrt(mask_h * mask_w))
        k = max(mask_size // 10, 10)
        kernel = np.ones((k, k), np.uint8)
        img_mask = cv2.erode(img_mask, kernel, iterations=1)
        kernel = np.ones((2, 2), np.uint8)
        fake_diff = cv2.dilate(fake_diff, kernel, iterations=1)
        k = max(mask_size // 20, 5)
        kernel_size = (k, k)
        blur_size = tuple(2 * i + 1 for i in kernel_size)
        img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)
        k = 5
        kernel_size = (k, k)
        blur_size = tuple(2 * i + 1 for i in kernel_size)
        fake_diff = cv2.GaussianBlur(fake_diff, blur_size, 0)
        img_mask /= 255
        fake_diff /= 255
        # NOTE(review): fake_diff is computed but, as in the reference
        # implementation, not used for the final blend -- the warped
        # white mask alone drives the compositing.
        img_mask = np.reshape(img_mask, [img_mask.shape[0], img_mask.shape[1], 1])
        fake_merged = img_mask * bgr_fake + (1 - img_mask) * target_img.astype(np.float32)
        fake_merged = fake_merged.astype(np.uint8)
        return fake_merged

    def get(self, img, target_face, source_face, paste_back=True):
        """Detect-free entry point: align, swap, and optionally paste back.

        Args:
            img: full target frame (BGR uint8).
            target_face: face object with ``kps`` landmarks for alignment.
            source_face: identity donor with ``normed_embedding``.
            paste_back: when False, return the raw swapped crop and the
                alignment matrix instead of compositing.

        Returns:
            ``(bgr_fake, M)`` when ``paste_back`` is False, otherwise the
            full frame with the swapped face blended in.
        """
        aimg, M = face_align.norm_crop2(img, target_face.kps, self.input_size[0])
        bgr_fake = self.swap(aimg, source_face)
        if not paste_back:
            return bgr_fake, M
        # BUG FIX: the merged frame was previously computed but never
        # returned, so paste_back=True yielded None.
        return self.swapAndPasteBack(img, aimg, M, bgr_fake)
0 commit comments