
Commit 474c3d9

Merge pull request #13 from RongLirr/main
Integrate DTW calculation in training loop
2 parents db5bd90 + d9b4b94 · commit 474c3d9
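The headline change is a DTW (dynamic time warping) metric wired into the training loop's validation step. The training-loop code itself is outside this excerpt, so the sketch below is only a minimal illustration of the standard DTW recurrence on pose sequences; the function name, input shapes, and the per-frame L2 cost are assumptions, not the repository's implementation.

```python
# Illustrative DTW sketch; shapes and cost function are assumptions.
import numpy as np

def dtw_distance(seq_a: np.ndarray, seq_b: np.ndarray) -> float:
    """DTW between two pose sequences of shape (T, keypoints, dims)."""
    a = seq_a.reshape(len(seq_a), -1)   # flatten each frame to a vector
    b = seq_b.reshape(len(seq_b), -1)
    T_a, T_b = len(a), len(b)
    cost = np.full((T_a + 1, T_b + 1), np.inf)
    cost[0, 0] = 0.0
    for i in range(1, T_a + 1):
        for j in range(1, T_b + 1):
            d = np.linalg.norm(a[i - 1] - b[j - 1])   # per-frame L2 cost
            cost[i, j] = d + min(cost[i - 1, j],       # insertion
                                 cost[i, j - 1],       # deletion
                                 cost[i - 1, j - 1])   # match
    return float(cost[T_a, T_b])
```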

File tree

10 files changed: +1433 -702 lines


fluent_pose_synthesis/.style.yapf

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+[style]
+based_on_style = pep8
+column_limit = 120
+split_before_named_assigns = false
+coalesce_brackets = true
+split_before_expression_after_opening_paren = false
+split_arguments_when_comma_terminated = false
+each_dict_entry_on_separate_line = false
+indent_dictionary_value = false
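This new .style.yapf pins the formatter that produced the whitespace-only churn in the diffs below (column limit 120, collapsed brackets). For reference, the config can also be exercised through yapf's Python API; `FormatCode` is yapf's real API, though the sample source string here is illustrative:

```python
# Format a source string against the repository's .style.yapf.
# FormatCode returns a (formatted_source, changed_flag) tuple.
from yapf.yapflib.yapf_api import FormatCode

source = "def f( a,b ):\n    return a+b\n"
formatted, changed = FormatCode(source, style_config=".style.yapf")
print(formatted)  # def f(a, b):\n    return a + b
```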

fluent_pose_synthesis/config/default.json

Lines changed: 15 additions & 5 deletions
@@ -11,21 +11,31 @@
     "dropout": 0.2,
     "activation": "gelu",
     "ablation": null,
-    "legacy": false
+    "legacy": false,
+    "history_len": 5
   },
   "diff": {
     "noise_schedule": "cosine",
-    "diffusion_steps": 32,
+    "diffusion_steps": 8,
     "sigma_small": true
   },
   "trainer": {
-    "epoch": 300,
+    "epoch": 500,
     "lr": 1e-4,
     "batch_size": 1024,
-    "cond_mask_prob": 0,
+    "cond_mask_prob": 0.15,
     "use_loss_mse": true,
     "use_loss_vel": true,
+    "use_loss_accel": false,
+    "lambda_vel": 1.0,
+    "lambda_accel": 1.0,
+    "guidance_scale": 2.0,
     "workers": 4,
-    "load_num": 200
+    "load_num": -1,
+    "validation_max_len": 160,
+    "validation_chunk_size": 40,
+    "validation_stop_threshold": 1e-4,
+    "eval_freq": 1,
+    "use_amp": false
   }
 }
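The new trainer keys pair up: `cond_mask_prob: 0.15` drops the disfluent condition for roughly 15% of training samples (see `interface` in models.py below), which is what makes classifier-free guidance possible, while `guidance_scale: 2.0` sets how the conditional and unconditional predictions are blended at sampling time. The sampling-side combination is not part of this diff; the sketch below shows only the conventional CFG formula, as an assumption about how the scale is used:

```python
# Conventional classifier-free guidance combination; how this repository
# applies guidance_scale at inference is not shown in this diff.
import torch

def cfg_combine(pred_cond: torch.Tensor, pred_uncond: torch.Tensor,
                guidance_scale: float = 2.0) -> torch.Tensor:
    # guidance_scale = 1.0 recovers the purely conditional prediction;
    # larger values push the output further toward the condition.
    return pred_uncond + guidance_scale * (pred_cond - pred_uncond)
```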

fluent_pose_synthesis/config/option.py

Lines changed: 20 additions & 7 deletions
@@ -10,22 +10,32 @@ def add_model_args(parser):
     parser.add_argument('--num_heads', type=int, default=4, help='Number of attention heads.')
     parser.add_argument('--num_layers', type=int, default=4, help='Number of model layers.')
 
+
 def add_diffusion_args(parser):
     parser.add_argument('--noise_schedule', type=str, default='cosine', help='Noise schedule: "cosine", "linear", etc.')
-    parser.add_argument('--diffusion_steps', type=int, default=4, help='Number of diffusion steps.')
+    parser.add_argument('--diffusion_steps', type=int, default=8, help='Number of diffusion steps.')
     parser.add_argument('--sigma_small', action='store_true', help='Use small sigma values.')
 
+
 def add_train_args(parser):
-    parser.add_argument('--epoch', type=int, default=300, help='Number of training epochs.')
+    parser.add_argument('--epoch', type=int, default=500, help='Number of training epochs.')
     parser.add_argument('--lr', type=float, default=0.00005, help='Learning rate.')
     parser.add_argument('--lr_anneal_steps', type=int, default=0, help='Annealing steps.')
-    parser.add_argument('--weight_decay', type=float, default=0.00, help='Weight decay.')
+    parser.add_argument('--weight_decay', type=float, default=0.001, help='Weight decay.')
     parser.add_argument('--batch_size', type=int, default=1024, help='Batch size.')
-    parser.add_argument('--cond_mask_prob', type=float, default=0, help='Conditioning mask probability.')
+    parser.add_argument('--cond_mask_prob', type=float, default=0.15, help='Conditioning mask probability.')
     parser.add_argument('--workers', type=int, default=4, help='Data loader workers.')
-    parser.add_argument('--ema', default=False, type=bool, help='Use Exponential Moving Average (EMA) for model parameters.')
+    parser.add_argument('--ema', default=False, type=bool,
+                        help='Use Exponential Moving Average (EMA) for model parameters.')
     parser.add_argument('--lambda_vel', type=float, default=1.0, help='Weight factor for the velocity loss term.')
+    parser.add_argument('--use_loss_vel', action='store_true', default=True, help='Enable velocity loss term.')
+    parser.add_argument('--use_loss_accel', action='store_true', default=False, help='Enable acceleration loss term.')
+    parser.add_argument('--lambda_accel', type=float, default=1.0, help='Weight factor for the acceleration loss term.')
+    parser.add_argument('--guidance_scale', type=float, default=2.0,
+                        help='Classifier-free guidance scale for inference.')
     parser.add_argument('--load_num', type=int, default=-1, help='Number of models to load.')
+    parser.add_argument('--use_amp', action='store_true', default=False, help='Use mixed precision training (AMP).')
+    parser.add_argument('--eval_freq', type=int, default=1, help='Frequency of evaluation during training.')
 
 
 def config_parse(args):
@@ -49,14 +59,17 @@ def config_parse(args):
     config.trainer.lr_anneal_steps = args.lr_anneal_steps
     config.trainer.weight_decay = args.weight_decay
     config.trainer.batch_size = args.batch_size
-    config.trainer.ema = True #if args.ema else config.trainer.ema
+    config.trainer.ema = True  #if args.ema else config.trainer.ema
     config.trainer.cond_mask_prob = args.cond_mask_prob
     config.trainer.workers = args.workers
     config.trainer.save_freq = int(config.trainer.epoch // 5)
     config.trainer.lambda_vel = args.lambda_vel
+    config.trainer.use_loss_vel = args.use_loss_vel
+    config.trainer.use_loss_accel = args.use_loss_accel
+    config.trainer.lambda_accel = args.lambda_accel
+    config.trainer.guidance_scale = args.guidance_scale
     config.trainer.load_num = args.load_num
 
-
     # Save directory
     data_prefix = args.data.split('/')[-1].split('.')[0]
     config.save = f'{args.save}/{args.name}_{data_prefix}' if 'debug' not in config.name else f'{args.save}/{args.name}'
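One caveat in the new flags: `--use_loss_vel` combines `action='store_true'` with `default=True`, so the flag is always on and cannot be disabled from the command line, and `--ema` keeps `type=bool`, where argparse converts any non-empty string (including "False") to True. If CLI toggling is wanted, `argparse.BooleanOptionalAction` (Python 3.9+) is the usual fix; a sketch, not the repository's code:

```python
# Explicit on/off boolean flags (Python 3.9+). Whether the always-on
# behavior of --use_loss_vel is intended here is an assumption.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use_loss_vel', action=argparse.BooleanOptionalAction,
                    default=True, help='Enable velocity loss term.')
parser.add_argument('--ema', action=argparse.BooleanOptionalAction,
                    default=False, help='Use EMA for model parameters.')

args = parser.parse_args(['--no-use_loss_vel', '--ema'])
print(args.use_loss_vel, args.ema)  # False True
```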

fluent_pose_synthesis/core/models.py

Lines changed: 53 additions & 98 deletions
@@ -9,22 +9,19 @@ class OutputProcessMLP(nn.Module):
     """
     Output process for the Sign Language Pose Diffusion model.
     """
-    def __init__(self, input_feats, latent_dim, njoints, nfeats, hidden_dim=512): # add hidden_dim as parameter
+
+    def __init__(self, input_feats, latent_dim, njoints, nfeats, hidden_dim=512):  # add hidden_dim as parameter
         super().__init__()
         self.input_feats = input_feats
         self.latent_dim = latent_dim
         self.njoints = njoints
         self.nfeats = nfeats
-        self.hidden_dim = hidden_dim # store hidden dimension
+        self.hidden_dim = hidden_dim  # store hidden dimension
 
         # MLP layers
-        self.mlp = nn.Sequential(
-            nn.Linear(self.latent_dim, self.hidden_dim),
-            nn.SiLU(),
-            nn.Linear(self.hidden_dim, self.hidden_dim // 2),
-            nn.SiLU(),
-            nn.Linear(self.hidden_dim // 2, self.input_feats)
-        )
+        self.mlp = nn.Sequential(nn.Linear(self.latent_dim, self.hidden_dim), nn.SiLU(),
+                                 nn.Linear(self.hidden_dim, self.hidden_dim // 2), nn.SiLU(),
+                                 nn.Linear(self.hidden_dim // 2, self.input_feats))
 
     def forward(self, output):
         nframes, bs, d = output.shape
@@ -39,25 +36,11 @@ class SignLanguagePoseDiffusion(nn.Module):
     Sign Language Pose Diffusion model.
     """
 
-    def __init__(
-        self,
-        input_feats: int,
-        chunk_len: int,
-        keypoints: int,
-        dims: int,
-        latent_dim: int = 256,
-        ff_size: int = 1024,
-        num_layers: int = 8,
-        num_heads: int = 4,
-        dropout: float = 0.2,
-        ablation: Optional[str] = None,
-        activation: str = "gelu",
-        legacy: bool = False,
-        arch: str = "trans_enc",
-        cond_mask_prob: float = 0,
-        device: Optional[torch.device] = None,
-        batch_first: bool = True
-    ):
+    def __init__(self, input_feats: int, chunk_len: int, keypoints: int, dims: int, latent_dim: int = 256,
+                 ff_size: int = 1024, num_layers: int = 8, num_heads: int = 4, dropout: float = 0.2,
+                 ablation: Optional[str] = None, activation: str = "gelu", legacy: bool = False,
+                 arch: str = "trans_enc", cond_mask_prob: float = 0, device: Optional[torch.device] = None,
+                 batch_first: bool = True):
         """
         Args:
             input_feats (int): Number of input features (keypoints * dimensions).
@@ -105,37 +88,33 @@ def __init__(
         # Define sequence encoder based on chosen architecture
         if self.arch == "trans_enc":
             print(f"Initializing Transformer Encoder (batch_first={self.batch_first})")
-            encoder_layer = nn.TransformerEncoderLayer(
-                d_model=latent_dim, nhead=num_heads, dim_feedforward=ff_size,
-                dropout=dropout, activation=activation, batch_first=self.batch_first
-            )
+            encoder_layer = nn.TransformerEncoderLayer(d_model=latent_dim, nhead=num_heads, dim_feedforward=ff_size,
+                                                       dropout=dropout, activation=activation,
+                                                       batch_first=self.batch_first)
             self.sequence_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
         elif self.arch == "trans_dec":
             print(f"Initializing Transformer Decoder (batch_first={self.batch_first})")
-            decoder_layer = nn.TransformerDecoderLayer(
-                d_model=latent_dim, nhead=num_heads, dim_feedforward=ff_size,
-                dropout=dropout, activation=activation, batch_first=self.batch_first
-            )
+            decoder_layer = nn.TransformerDecoderLayer(d_model=latent_dim, nhead=num_heads, dim_feedforward=ff_size,
+                                                       dropout=dropout, activation=activation,
+                                                       batch_first=self.batch_first)
             self.sequence_encoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
         elif self.arch == "gru":
-            print("Initializing GRU Encoder (batch_first=True)")
-            self.sequence_encoder = nn.GRU(
-                latent_dim, latent_dim, num_layers=num_layers, batch_first=True
-            )
+            print("Initializing GRU Encoder (batch_first=True)")
+            self.sequence_encoder = nn.GRU(latent_dim, latent_dim, num_layers=num_layers, batch_first=True)
         else:
             raise ValueError("Please choose correct architecture [trans_enc, trans_dec, gru]")
 
         # Pose projection: projects latent representation back to pose space.
         # The OutputProcess returns (B, keypoints, dims, T); apply a post_transform to get (B, T, keypoints, dims)
-        self.pose_projection = OutputProcessMLP(input_feats, latent_dim, keypoints, dims, hidden_dim=512)
+        self.pose_projection = OutputProcessMLP(input_feats, latent_dim, keypoints, dims, hidden_dim=1024)
         self.to(self.device)
 
     def forward(
-        self,
-        fluent_clip: torch.Tensor, # (B, K, D, T_chunk)
-        disfluent_seq: torch.Tensor, # (B, K, D, T_disfl)
-        t: torch.Tensor, # (B,)
-        previous_output: Optional[torch.Tensor] = None # (B, K, D, T_hist)
+            self,
+            fluent_clip: torch.Tensor,  # (B, K, D, T_chunk)
+            disfluent_seq: torch.Tensor,  # (B, K, D, T_disfl)
+            t: torch.Tensor,  # (B,)
+            previous_output: Optional[torch.Tensor] = None  # (B, K, D, T_hist)
     ) -> torch.Tensor:
 
         # # --- DEBUG: Print Initial Input Shapes ---
@@ -158,35 +137,35 @@ def forward(
         T_chunk = fluent_clip.shape[-1]
 
         # 1. Embed Timestep
-        _t_emb_raw = self.embed_timestep(t) # Expected (B, D)
+        _t_emb_raw = self.embed_timestep(t)  # Expected (B, D)
         # print(f"[DEBUG FWD 1a] Raw t_emb shape: {_t_emb_raw.shape}")
-        t_emb = _t_emb_raw.permute(1, 0, 2)
+        t_emb = _t_emb_raw.permute(1, 0, 2).contiguous()
         # print(f"[DEBUG FWD 1b] Final t_emb shape: {t_emb.shape}")
 
         # 2. Embed Disfluent Sequence (Condition)
-        _disfluent_emb_raw = self.disfluent_encoder(disfluent_seq) # Expected (T_disfl, B, D)
+        _disfluent_emb_raw = self.disfluent_encoder(disfluent_seq)  # Expected (T_disfl, B, D)
         # print(f"[DEBUG FWD 2a] Raw disfluent_emb shape: {_disfluent_emb_raw.shape}")
-        disfluent_emb = _disfluent_emb_raw.permute(1, 0, 2) # Expected (B, T_disfl, D)
+        disfluent_emb = _disfluent_emb_raw.permute(1, 0, 2).contiguous()  # Expected (B, T_disfl, D)
         # print(f"[DEBUG FWD 2b] Final disfluent_emb shape: {disfluent_emb.shape}")
 
         # 3. Embed Previous Output (History), if available
         embeddings_to_concat = [t_emb, disfluent_emb]
         # print("[DEBUG FWD 3a] Processing previous_output...")
         if previous_output is not None and previous_output.shape[-1] > 0:
             # print(f"[DEBUG FWD 3b] History Input shape: {previous_output.shape}")
-            _prev_out_emb_raw = self.fluent_encoder(previous_output) # Expected (T_hist, B, D)
+            _prev_out_emb_raw = self.fluent_encoder(previous_output)  # Expected (T_hist, B, D)
             # print(f"[DEBUG FWD 3c] Raw prev_out_emb shape: {_prev_out_emb_raw.shape}")
-            prev_out_emb = _prev_out_emb_raw.permute(1, 0, 2) # Expected (B, T_hist, D)
+            prev_out_emb = _prev_out_emb_raw.permute(1, 0, 2).contiguous()  # Expected (B, T_hist, D)
             # print(f"[DEBUG FWD 3d] Final prev_out_emb shape: {prev_out_emb.shape}")
             embeddings_to_concat.append(prev_out_emb)
         else:
             # print("[DEBUG FWD 3b] No previous_output provided or it's empty.")
             pass
 
         # 4. Embed Current Fluent Clip (Noisy Target 'x')
-        _fluent_emb_raw = self.fluent_encoder(fluent_clip) # Expected (T_chunk, B, D)
+        _fluent_emb_raw = self.fluent_encoder(fluent_clip)  # Expected (T_chunk, B, D)
         # print(f"[DEBUG FWD 4a] Raw fluent_emb shape: {_fluent_emb_raw.shape}")
-        fluent_emb = _fluent_emb_raw.permute(1, 0, 2) # Expected (B, T_chunk, D)
+        fluent_emb = _fluent_emb_raw.permute(1, 0, 2).contiguous()  # Expected (B, T_chunk, D)
         # print(f"[DEBUG FWD 4b] Final fluent_emb shape: {fluent_emb.shape}")
         embeddings_to_concat.append(fluent_emb)
 
@@ -198,33 +177,33 @@ def forward(
         # print(f"[DEBUG FWD 6a] xseq shape before PositionalEncoding: {xseq.shape}")
         # Adapt based on PositionalEncoding expectation (T, B, D) vs batch_first
         if self.batch_first:
-            xseq_permuted = xseq.permute(1, 0, 2) # (T_total, B, D)
+            xseq_permuted = xseq.permute(1, 0, 2).contiguous()  # (T_total, B, D)
             # print(f"[DEBUG FWD 6b] xseq permuted for PosEnc: {xseq_permuted.shape}")
             xseq_encoded = self.sequence_pos_encoder(xseq_permuted)
             # print(f"[DEBUG FWD 6c] xseq after PosEnc: {xseq_encoded.shape}")
-            xseq = xseq_encoded.permute(1, 0, 2) # Back to (B, T_total, D)
+            xseq = xseq_encoded.permute(1, 0, 2)  # Back to (B, T_total, D)
             # print(f"[DEBUG FWD 6d] xseq permuted back: {xseq.shape}")
         else:
-            # If not batch_first, assume xseq should be (T, B, D) already
-            # Need to adjust concatenation and permutations above if batch_first=False
-            xseq = xseq.permute(1, 0, 2) # Assume needs (T, B, D)
+            # If not batch_first, assume xseq should be (T, B, D) already
+            # Need to adjust concatenation and permutations above if batch_first=False
+            xseq = xseq.permute(1, 0, 2)  # Assume needs (T, B, D)
             # print(f"[DEBUG FWD 6b] xseq permuted for PosEnc (batch_first=False): {xseq.shape}")
-            xseq = self.sequence_pos_encoder(xseq)
-            # print(f"[DEBUG FWD 6c] xseq after PosEnc (batch_first=False): {xseq.shape}")
-            # Keep as (T, B, D) if encoder needs it
+            xseq = self.sequence_pos_encoder(xseq)
+            # print(f"[DEBUG FWD 6c] xseq after PosEnc (batch_first=False): {xseq.shape}")
+            # Keep as (T, B, D) if encoder needs it
 
         # 7. Process through sequence encoder
         # print(f"[DEBUG FWD 7a] Input to sequence_encoder ({self.arch}) shape: {xseq.shape}")
         if self.arch == "trans_enc":
-            x_encoded = self.sequence_encoder(xseq)
+            x_encoded = self.sequence_encoder(xseq)
         elif self.arch == "gru":
-            x_encoded, _ = self.sequence_encoder(xseq)
+            x_encoded, _ = self.sequence_encoder(xseq)
         elif self.arch == "trans_dec":
-            memory = xseq
-            tgt = xseq
-            x_encoded = self.sequence_encoder(tgt=tgt, memory=memory)
+            memory = xseq
+            tgt = xseq
+            x_encoded = self.sequence_encoder(tgt=tgt, memory=memory)
         else:
-            raise ValueError("Unsupported architecture")
+            raise ValueError("Unsupported architecture")
         # print(f"[DEBUG FWD 7b] Output from sequence_encoder shape: {x_encoded.shape}")
 
         # 8. Extract the output corresponding to the target fluent_clip
@@ -249,30 +228,11 @@ def forward(
 
         return output
 
-
-    # def mingyi_forward(
-    #     self, fluent_clip: torch.Tensor, disfluent_seq: torch.Tensor, t: torch.Tensor
-    # ) -> torch.Tensor:
-
-    #     batch_size, keypoints, dims, time = fluent_clip.shape
-
-    #     t_emb = self.embed_timestep(t)
-    #     disfluent_emb = self.disfluent_encoder(disfluent_seq)
-    #     fluent_emb = self.fluent_encoder(fluent_clip)
-
-    #     xseq = torch.cat((t_emb, disfluent_emb, fluent_emb), axis=0)
-    #     xseq = self.sequence_pos_encoder(xseq)
-
-    #     x_out = self.sequence_encoder(xseq)[:time]
-    #     output = self.pose_projection(x_out)
-
-    #     return output
-
     def interface(
-        self,
-        fluent_clip: torch.Tensor, # (B, K, D, T_chunk)
-        t: torch.Tensor, # (B,)
-        y: dict[str, torch.Tensor] # Conditions dict
+            self,
+            fluent_clip: torch.Tensor,  # (B, K, D, T_chunk)
+            t: torch.Tensor,  # (B,)
+            y: dict[str, torch.Tensor]  # Conditions dict
     ) -> torch.Tensor:
         """
         Interface for Classifier-Free Guidance (CFG). Handles previous_output.
@@ -283,13 +243,8 @@ def interface(
         previous_output = y.get("previous_output", None)
 
         # Apply CFG: randomly drop the condition with probability cond_mask_prob
-        keep_batch_idx = torch.rand(batch_size, device=disfluent_seq.device) < (1-self.cond_mask_prob)
+        keep_batch_idx = torch.rand(batch_size, device=disfluent_seq.device) < (1 - self.cond_mask_prob)
         disfluent_seq = disfluent_seq * keep_batch_idx.view((batch_size, 1, 1, 1))
 
         # Call the forward function
-        return self.forward(
-            fluent_clip=fluent_clip,
-            disfluent_seq=disfluent_seq,
-            t=t,
-            previous_output=previous_output
-        )
+        return self.forward(fluent_clip=fluent_clip, disfluent_seq=disfluent_seq, t=t, previous_output=previous_output)
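For orientation, `interface` is the entry point a diffusion sampler would call once per denoising step. The usage sketch below is illustrative only: the tensor sizes are invented, and the `y` key for the disfluent condition is an assumption (only "previous_output" is visible in this diff):

```python
# Hypothetical call into SignLanguagePoseDiffusion.interface.
# B/K/D are made-up example values; the "input_sequence" key name is an
# assumption, while chunk/history lengths mirror the config above.
import torch

B, K, D = 2, 178, 3                      # batch, keypoints, dims (examples)
T_chunk, T_disfl, T_hist = 40, 160, 5    # validation_chunk_size=40, history_len=5

model = SignLanguagePoseDiffusion(input_feats=K * D, chunk_len=T_chunk, keypoints=K, dims=D)
x_t = torch.randn(B, K, D, T_chunk)      # noisy fluent chunk at timestep t
t = torch.randint(0, 8, (B,))            # one timestep per sample (diffusion_steps=8)
y = {
    "input_sequence": torch.randn(B, K, D, T_disfl),  # disfluent condition (key assumed)
    "previous_output": torch.randn(B, K, D, T_hist),  # history window
}
pred = model.interface(x_t, t, y)        # -> (B, K, D, T_chunk)
```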
