forked from soubhiksanyal/FLAME_PyTorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathFLAME.py
249 lines (206 loc) · 11.6 KB
/
FLAME.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
"""
FLAME Layer: Implementation of the 3D Statistical Face model in PyTorch
It is designed in a way to directly plug in as a decoder layer in a
Deep learning framework for training and testing
It can also be used for 2D or 3D optimisation applications
Author: Soubhik Sanyal
Copyright (c) 2019, Soubhik Sanyal
All rights reserved.
Max-Planck-Gesellschaft zur Foerderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights on this
computer program.
You can only use this computer program if you have closed a license agreement with MPG or you get the right to use
the computer program from someone who is authorized to grant you that right.
Any use of the computer program without a valid license is prohibited and liable to prosecution.
Copyright 2019 Max-Planck-Gesellschaft zur Foerderung der Wissenschaften e.V. (MPG). acting on behalf of its
Max Planck Institute for Intelligent Systems and the Max Planck Institute for Biological Cybernetics.
All rights reserved.
More information about FLAME is available at http://flame.is.tue.mpg.de.
For questions regarding the PyTorch implementation please contact soubhik.sanyal@tuebingen.mpg.de
"""
# Modified from smplx code [https://github.com/vchoutas/smplx] for FLAME
import numpy as np
import torch
import torch.nn as nn
import pickle
from smplx.lbs import lbs, batch_rodrigues, vertices2landmarks, find_dynamic_lmk_idx_and_bcoords
from smplx.utils import Struct, to_tensor, to_np, rot_mat_to_euler
class FLAME(nn.Module):
    """
    Differentiable FLAME decoder layer.

    Given FLAME parameters (shape, expression, pose), this module generates
    the corresponding 3D face mesh vertices and a set of 3D facial landmarks.
    It can be plugged directly into a deep-learning pipeline as a decoder, or
    used for 2D/3D optimisation.

    Buffers/parameters registered here (``shape_betas``, ``expression_betas``,
    ``eye_pose``, ``neck_pose``, ``transl`` and the model buffers) keep the
    original names so existing checkpoints/state_dicts remain loadable.
    """

    def __init__(self, config):
        """
        Args:
            config: object with attributes
                flame_model_path: path to the pickled FLAME model,
                static_landmark_embedding_path: path to the static landmark pickle,
                dynamic_landmark_embedding_path: path to the contour embedding file,
                batch_size: fixed batch size the layer is built for,
                shape_params: number of free shape betas (out of 300),
                expression_params: number of free expression betas (out of 100),
                use_face_contour: whether to also produce dynamic contour landmarks,
                use_3D_translation: whether to apply `transl` to the outputs.
        """
        super(FLAME, self).__init__()
        print("creating the FLAME Decoder")
        # NOTE(review): pickle.load executes arbitrary code on malicious files --
        # only load trusted FLAME model files.
        with open(config.flame_model_path, 'rb') as f:
            self.flame_model = Struct(**pickle.load(f, encoding='latin1'))
        # Index of the neck joint in the FLAME kinematic tree.
        self.NECK_IDX = 1
        self.batch_size = config.batch_size
        self.dtype = torch.float32
        self.use_face_contour = config.use_face_contour
        self.faces = self.flame_model.f
        self.register_buffer('faces_tensor',
                             to_tensor(to_np(self.faces, dtype=np.int64),
                                       dtype=torch.long))

        # Fixing remaining shape betas: FLAME has 300 shape parameters in
        # total, but only the first `config.shape_params` are driven by the
        # caller; the rest are frozen at zero (e.g. RingNet uses 100).
        default_shape = torch.zeros([self.batch_size, 300 - config.shape_params],
                                    dtype=self.dtype, requires_grad=False)
        self.register_parameter('shape_betas',
                                nn.Parameter(default_shape, requires_grad=False))

        # Fixing remaining expression betas: 100 in total, only the first
        # `config.expression_params` are free (e.g. RingNet uses 50).
        default_exp = torch.zeros([self.batch_size, 100 - config.expression_params],
                                  dtype=self.dtype, requires_grad=False)
        self.register_parameter('expression_betas',
                                nn.Parameter(default_exp, requires_grad=False))

        # Default eyeball (left+right, axis-angle: 6) and neck (axis-angle: 3)
        # rotations, frozen at zero; callers may override them in forward().
        default_eyball_pose = torch.zeros([self.batch_size, 6],
                                          dtype=self.dtype, requires_grad=False)
        self.register_parameter('eye_pose',
                                nn.Parameter(default_eyball_pose, requires_grad=False))
        default_neck_pose = torch.zeros([self.batch_size, 3],
                                        dtype=self.dtype, requires_grad=False)
        self.register_parameter('neck_pose',
                                nn.Parameter(default_neck_pose, requires_grad=False))

        # Fixing 3D translation since translation is typically handled in the
        # image plane; only applied when `use_3D_translation` is set.
        self.use_3D_translation = config.use_3D_translation
        default_transl = torch.zeros([self.batch_size, 3],
                                     dtype=self.dtype, requires_grad=False)
        self.register_parameter(
            'transl',
            nn.Parameter(default_transl, requires_grad=False))

        # The vertices of the template model (V x 3).
        self.register_buffer('v_template',
                             to_tensor(to_np(self.flame_model.v_template),
                                       dtype=self.dtype))

        # The shape components (identity + expression blendshape basis).
        shapedirs = self.flame_model.shapedirs
        self.register_buffer(
            'shapedirs',
            to_tensor(to_np(shapedirs), dtype=self.dtype))

        # Regressor producing joint locations from vertices.
        j_regressor = to_tensor(to_np(
            self.flame_model.J_regressor), dtype=self.dtype)
        self.register_buffer('J_regressor', j_regressor)

        # Pose blend shape basis, reshaped to (num_pose_basis, V*3) as
        # expected by smplx.lbs.
        num_pose_basis = self.flame_model.posedirs.shape[-1]
        posedirs = np.reshape(self.flame_model.posedirs, [-1, num_pose_basis]).T
        self.register_buffer('posedirs',
                             to_tensor(to_np(posedirs), dtype=self.dtype))

        # Parent index of each joint in the kinematic tree; the root's parent
        # is set to -1 as the chain terminator.
        parents = to_tensor(to_np(self.flame_model.kintree_table[0])).long()
        parents[0] = -1
        self.register_buffer('parents', parents)

        # Linear blend skinning weights (V x num_joints).
        self.register_buffer('lbs_weights',
                             to_tensor(to_np(self.flame_model.weights), dtype=self.dtype))

        # Static landmark embedding: face indices + barycentric coordinates
        # defining fixed landmarks on the mesh surface.
        with open(config.static_landmark_embedding_path, 'rb') as f:
            static_embeddings = Struct(**pickle.load(f, encoding='latin1'))
        lmk_faces_idx = (static_embeddings.lmk_face_idx).astype(np.int64)
        self.register_buffer('lmk_faces_idx',
                             torch.tensor(lmk_faces_idx, dtype=torch.long))
        lmk_bary_coords = static_embeddings.lmk_b_coords
        self.register_buffer('lmk_bary_coords',
                             torch.tensor(lmk_bary_coords, dtype=self.dtype))

        if self.use_face_contour:
            # Dynamic (contour) landmark embedding: one row per discretised
            # head-yaw angle, selected at run time in forward().
            conture_embeddings = np.load(config.dynamic_landmark_embedding_path,
                                         allow_pickle=True, encoding='latin1')
            conture_embeddings = conture_embeddings[()]
            dynamic_lmk_faces_idx = np.array(
                conture_embeddings['lmk_face_idx']).astype(np.int64)
            dynamic_lmk_faces_idx = torch.tensor(
                dynamic_lmk_faces_idx,
                dtype=torch.long)
            self.register_buffer('dynamic_lmk_faces_idx',
                                 dynamic_lmk_faces_idx)
            dynamic_lmk_bary_coords = conture_embeddings['lmk_b_coords']
            dynamic_lmk_bary_coords = torch.tensor(
                dynamic_lmk_bary_coords, dtype=self.dtype)
            self.register_buffer('dynamic_lmk_bary_coords',
                                 dynamic_lmk_bary_coords)

            # Kinematic chain from the neck joint up to the root; used to
            # accumulate the head's rotation relative to the body.
            neck_kin_chain = []
            curr_idx = torch.tensor(self.NECK_IDX, dtype=torch.long)
            while curr_idx != -1:
                neck_kin_chain.append(curr_idx)
                curr_idx = self.parents[curr_idx]
            self.register_buffer('neck_kin_chain',
                                 torch.stack(neck_kin_chain))

    def _find_dynamic_lmk_idx_and_bcoords(self, vertices, pose, dynamic_lmk_faces_idx,
                                          dynamic_lmk_b_coords,
                                          neck_kin_chain, dtype=torch.float32):
        """
        Selects the face contour landmarks depending on the relative position
        of the head (yaw), per batch element.

        Args:
            vertices: (N, num_vertices, 3) posed mesh vertices.
            pose: (N, full_pose) axis-angle full pose vector.
            dynamic_lmk_faces_idx: lookup table of contour face indices,
                one row per discretised yaw angle.
            dynamic_lmk_b_coords: matching barycentric-weight lookup table.
            neck_kin_chain: joint indices from the neck up to the root.
            dtype: computation dtype.

        Returns:
            (dyn_lmk_faces_idx, dyn_lmk_b_coords): the contour face indices
            and barycentric weights selected for each batch element.

        Source: modified for batches from https://github.com/vchoutas/smplx
        """
        batch_size = vertices.shape[0]
        # Pick the axis-angle rotations of the joints along the neck chain.
        aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1,
                                     neck_kin_chain)
        rot_mats = batch_rodrigues(
            aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3)

        # Accumulate the chain into a single head-relative rotation matrix.
        rel_rot_mat = torch.eye(3, device=vertices.device,
                                dtype=dtype).unsqueeze_(dim=0).expand(batch_size, -1, -1)
        for idx in range(len(neck_kin_chain)):
            rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)

        # Convert the yaw to integer degrees, clamped to the table range
        # [-39, 39]; negative angles are remapped to the rows 40..78.
        y_rot_angle = torch.round(
            torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi,
                        max=39)).to(dtype=torch.long)
        neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
        mask = y_rot_angle.lt(-39).to(dtype=torch.long)
        neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle)
        y_rot_angle = (neg_mask * neg_vals +
                       (1 - neg_mask) * y_rot_angle)

        # Look up the contour faces/weights for each batch element's yaw bin.
        dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx,
                                               0, y_rot_angle)
        dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords,
                                              0, y_rot_angle)
        return dyn_lmk_faces_idx, dyn_lmk_b_coords

    def forward(self, shape_params=None, expression_params=None, pose_params=None,
                neck_pose=None, eye_pose=None, transl=None):
        """
        Decode FLAME parameters into mesh vertices and 3D landmarks.

        Args:
            shape_params: (N, shape_params) shape betas, or None for zeros.
            expression_params: (N, expression_params) expression betas,
                or None for zeros.
            pose_params: (N, 6) global rotation (3) + jaw rotation (3) in
                axis-angle, or None for zeros.
            neck_pose: (N, 3) neck rotation, or None for the fixed default.
            eye_pose: (N, 6) eyeball rotations, or None for the fixed default.
            transl: (N, 3) translation, or None for the fixed default
                (only applied when `use_3D_translation` is set).

        Returns:
            vertices: (N, V, 3) posed mesh vertices.
            landmarks: (N, num_landmarks, 3) 3D landmark positions.
        """
        # Fall back to zeros for any optional driving parameter that was not
        # supplied, so every argument advertised as `=None` is truly optional
        # (previously None here crashed in torch.cat / indexing, unlike the
        # neck/eye/transl fallbacks below). The free-parameter counts are
        # recovered from the registered fixed-beta widths.
        if shape_params is None:
            shape_params = torch.zeros(
                [self.batch_size, 300 - self.shape_betas.shape[1]],
                dtype=self.dtype, device=self.shape_betas.device)
        if expression_params is None:
            expression_params = torch.zeros(
                [self.batch_size, 100 - self.expression_betas.shape[1]],
                dtype=self.dtype, device=self.expression_betas.device)
        if pose_params is None:
            pose_params = torch.zeros(
                [self.batch_size, 6],
                dtype=self.dtype, device=self.neck_pose.device)

        # Full beta vector: driven shape betas + fixed zeros + driven
        # expression betas + fixed zeros, matching the shapedirs layout.
        betas = torch.cat([shape_params, self.shape_betas,
                           expression_params, self.expression_betas], dim=1)
        neck_pose = (neck_pose if neck_pose is not None else self.neck_pose)
        eye_pose = (eye_pose if eye_pose is not None else self.eye_pose)
        transl = (transl if transl is not None else self.transl)

        # Joint order expected by lbs: global, neck, jaw, eyes.
        full_pose = torch.cat([pose_params[:, :3], neck_pose,
                               pose_params[:, 3:], eye_pose], dim=1)
        template_vertices = self.v_template.unsqueeze(0).repeat(self.batch_size, 1, 1)

        vertices, _ = lbs(betas, full_pose, template_vertices,
                          self.shapedirs, self.posedirs,
                          self.J_regressor, self.parents,
                          self.lbs_weights, dtype=self.dtype)

        # Static landmarks, expanded per batch element.
        lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).repeat(
            self.batch_size, 1)
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(
            self.batch_size, 1, 1)
        if self.use_face_contour:
            # Prepend the pose-dependent contour landmarks.
            dyn_lmk_faces_idx, dyn_lmk_bary_coords = self._find_dynamic_lmk_idx_and_bcoords(
                vertices, full_pose, self.dynamic_lmk_faces_idx,
                self.dynamic_lmk_bary_coords,
                self.neck_kin_chain, dtype=self.dtype)
            lmk_faces_idx = torch.cat([dyn_lmk_faces_idx, lmk_faces_idx], 1)
            lmk_bary_coords = torch.cat(
                [dyn_lmk_bary_coords, lmk_bary_coords], 1)

        landmarks = vertices2landmarks(vertices, self.faces_tensor,
                                       lmk_faces_idx,
                                       lmk_bary_coords)

        if self.use_3D_translation:
            landmarks += transl.unsqueeze(dim=1)
            vertices += transl.unsqueeze(dim=1)

        return vertices, landmarks