Skip to content

Commit

Permalink
Delete useless code
Browse files Browse the repository at this point in the history
  • Loading branch information
radekd91 committed Feb 10, 2023
1 parent 5954b45 commit c47f155
Showing 1 changed file with 142 additions and 0 deletions.
142 changes: 142 additions & 0 deletions gdl/models/DECA.py
Original file line number Diff line number Diff line change
Expand Up @@ -3225,6 +3225,7 @@ def decompose_code(self, code):

deca_code_list_copy = deca_code_list.copy()

# self.E_mica.cfg.model.n_shape

#TODO: clean this if-else block up
if self.config.exp_deca_global_pose and self.config.exp_deca_jaw_pose:
Expand Down Expand Up @@ -3284,3 +3285,144 @@ def train(self, mode: bool = True):
self.D_detail.eval()
return self



class EMICA(ExpDECA):

def __init__(self, config):
self.use_mica_shape_dim = True
# self.use_mica_shape_dim = False
from .mica.config import get_cfg_defaults
self.mica_cfg = get_cfg_defaults()
super().__init__(config)

def _create_model(self):
# 1) Initialize DECA
super()._create_model()
from .mica.mica import MICA
#TODO: MICA uses FLAME
# 1) This is redundant - get rid of it
# 2) Make sure it's the same FLAME as EMOCA
if Path(self.config.mica_model_path).exists():
mica_path = self.config.mica_model_path
else:
from gdl.utils.other import get_path_to_assets
mica_path = get_path_to_assets() / self.config.mica_model_path
assert mica_path.exists(), f"MICA model path does not exist: '{mica_path}'"

self.mica_cfg.pretrained_model_path = str(mica_path)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.E_mica = MICA(self.mica_cfg, device, str(mica_path), instantiate_flame=False)
# E_mica should be fixed
self.E_mica.requires_grad_(False)
self.E_mica.testing = True

# preprocessing for MICA
if self.config.mica_preprocessing:
from insightface.app import FaceAnalysis
self.app = FaceAnalysis(name='antelopev2', providers=['CUDAExecutionProvider'])
self.app.prepare(ctx_id=0, det_size=(224, 224))


def _get_num_shape_params(self):
if self.use_mica_shape_dim:
return self.mica_cfg.model.n_shape
return self.config.n_shape

def _get_coarse_trainable_parameters(self):
# MICA is not trainable so we don't wanna add it
return super()._get_coarse_trainable_parameters()


def train(self, mode: bool = True):
super().train(mode)
self.E_mica.train(False) # MICA is pretrained and will be set to EVAL at all times


def _encode_flame(self, images):

if self.config.mica_preprocessing:
mica_image = self._dirty_image_preprocessing(images)
else:
mica_image = F.interpolate(images, (112,112), mode='bilinear', align_corners=False)

deca_code, exp_deca_code = super()._encode_flame(images)
mica_code = self.E_mica.encode(images, mica_image)
mica_code = self.E_mica.decode(mica_code, predict_vertices=False)
return deca_code, exp_deca_code, mica_code['pred_shape_code']

def _dirty_image_preprocessing(self, input_image):
# breaks whatever gradient flow that may have gone into the image creation process
from gdl.models.mica.detector import get_center, get_arcface_input
from insightface.app.common import Face

image = input_image.detach().clone().cpu().numpy() * 255.
# b,c,h,w to b,h,w,c
image = image.transpose((0,2,3,1))

min_det_score = 0.5
image_list = list(image)
aligned_image_list = []
for i, img in enumerate(image_list):
bboxes, kpss = self.app.det_model.detect(img, max_num=0, metric='default')
if bboxes.shape[0] == 0:
aimg = resize(img, output_shape=(112,112), preserve_range=True)
aligned_image_list.append(aimg)
raise RuntimeError("No faces detected")
continue
i = get_center(bboxes, image)
bbox = bboxes[i, 0:4]
det_score = bboxes[i, 4]
# if det_score < min_det_score:
# continue
kps = None
if kpss is not None:
kps = kpss[i]

face = Face(bbox=bbox, kps=kps, det_score=det_score)
blob, aimg = get_arcface_input(face, img)
aligned_image_list.append(aimg)
aligned_images = np.array(aligned_image_list)
# b,h,w,c to b,c,h,w
aligned_images = aligned_images.transpose((0,3,1,2))
# to torch to correct device
aligned_images = torch.from_numpy(aligned_images).to(input_image.device)
return aligned_images

def decompose_code(self, code):
deca_code = code[0]
expdeca_code = code[1]
mica_code = code[2]

code_list, deca_code_list_copy = super().decompose_code((deca_code, expdeca_code), )

id_idx = 0 # identity is the first part of the vector
# assert self.config.n_shape == mica_code.shape[-1]
# assert code_list[id_idx].shape[-1] == mica_code.shape[-1]
if self.use_mica_shape_dim:
code_list[id_idx] = mica_code
else:
code_list[id_idx] = mica_code[..., :self.config.n_shape]
return code_list, deca_code_list_copy


def instantiate_deca(cfg, stage, prefix, checkpoint=None, checkpoint_kwargs=None):
"""
Function that instantiates a DecaModule from checkpoint or config
"""

if checkpoint is None:
deca = DecaModule(cfg.model, cfg.learning, cfg.inout, prefix)
if cfg.model.resume_training:
# This load the DECA model weights from the original DECA release
print("[WARNING] Loading EMOCA checkpoint pretrained by the old code")
deca.deca._load_old_checkpoint()
else:
checkpoint_kwargs = checkpoint_kwargs or {}
deca = DecaModule.load_from_checkpoint(checkpoint_path=checkpoint, strict=False, **checkpoint_kwargs)
if stage == 'train':
mode = True
else:
mode = False
deca.reconfigure(cfg.model, cfg.inout, cfg.learning, prefix, downgrade_ok=True, train=mode)
return deca

0 comments on commit c47f155

Please sign in to comment.