Integrate DTW calculation in training loop #13

Open · wants to merge 7 commits into base: main
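This change set wires a dynamic time warping (DTW) metric into the validation pass of the training loop, alongside new loss, guidance, and validation settings. As a point of reference, the sketch below shows the kind of DTW distance that would be computed between a generated and a reference pose sequence; the function and the frame-flattening choice are illustrative assumptions, not code from this PR.

```python
# Illustrative only: a minimal DTW distance between two pose sequences,
# each of shape (T, K, D). Not the implementation used in this PR.
import numpy as np

def dtw_distance(seq_a: np.ndarray, seq_b: np.ndarray) -> float:
    # Flatten each frame to a vector so frames are compared with L2 distance.
    a = seq_a.reshape(len(seq_a), -1)
    b = seq_b.reshape(len(seq_b), -1)
    T_a, T_b = len(a), len(b)
    cost = np.full((T_a + 1, T_b + 1), np.inf)
    cost[0, 0] = 0.0
    for i in range(1, T_a + 1):
        for j in range(1, T_b + 1):
            d = np.linalg.norm(a[i - 1] - b[j - 1])
            # Classic DTW recurrence: best of match, insertion, or deletion.
            cost[i, j] = d + min(cost[i - 1, j], cost[i, j - 1], cost[i - 1, j - 1])
    return float(cost[T_a, T_b])
```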
9 changes: 9 additions & 0 deletions fluent_pose_synthesis/.style.yapf
@@ -0,0 +1,9 @@
[style]
based_on_style = pep8
column_limit = 120
split_before_named_assigns = false
coalesce_brackets = true
split_before_expression_after_opening_paren = false
split_arguments_when_comma_terminated = false
each_dict_entry_on_separate_line = false
indent_dictionary_value = false
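The new `.style.yapf` pins the formatter settings (pep8 base, 120-column limit) behind the reformatting seen throughout this diff. A minimal way to apply it, assuming `yapf` is installed and the snippet is run from the repository root:

```python
# Illustrative: format a code snippet with the repo's style file via yapf's Python API.
from yapf.yapflib.yapf_api import FormatCode

snippet = "def f(a,b):\n    return {'x':a,'y':b}\n"
formatted, changed = FormatCode(snippet, style_config="fluent_pose_synthesis/.style.yapf")
print(formatted)
```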
20 changes: 15 additions & 5 deletions fluent_pose_synthesis/config/default.json
@@ -11,21 +11,31 @@
"dropout": 0.2,
"activation": "gelu",
"ablation": null,
"legacy": false
"legacy": false,
"history_len": 5
},
"diff": {
"noise_schedule": "cosine",
"diffusion_steps": 32,
"diffusion_steps": 8,
"sigma_small": true
},
"trainer": {
"epoch": 300,
"epoch": 500,
"lr": 1e-4,
"batch_size": 1024,
"cond_mask_prob": 0,
"cond_mask_prob": 0.15,
"use_loss_mse": true,
"use_loss_vel": true,
"use_loss_accel": false,
"lambda_vel": 1.0,
"lambda_accel": 1.0,
"guidance_scale": 2.0,
"workers": 4,
"load_num": 200
"load_num": -1,
"validation_max_len": 160,
"validation_chunk_size": 40,
"validation_stop_threshold": 1e-4,
"eval_freq": 1,
"use_amp": false
}
}
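The `trainer` block now exposes the velocity/acceleration loss switches, the classifier-free guidance scale, an AMP toggle, and the chunked-validation settings (`validation_max_len`, `validation_chunk_size`, `validation_stop_threshold`, `eval_freq`) that the evaluation pass reads. A minimal sketch of loading these values; the `SimpleNamespace` wrapper is illustrative and stands in for whatever dot-accessible config class the project actually uses:

```python
# Illustrative: read config/default.json with attribute-style access to the trainer block.
import json
from types import SimpleNamespace

def load_config(path: str) -> SimpleNamespace:
    with open(path, "r", encoding="utf-8") as f:
        # object_hook turns every nested dict into a SimpleNamespace.
        return json.load(f, object_hook=lambda d: SimpleNamespace(**d))

config = load_config("fluent_pose_synthesis/config/default.json")
print(config.trainer.guidance_scale)         # 2.0
print(config.trainer.validation_chunk_size)  # 40
print(config.trainer.eval_freq)              # 1
```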
27 changes: 20 additions & 7 deletions fluent_pose_synthesis/config/option.py
@@ -10,22 +10,32 @@ def add_model_args(parser):
parser.add_argument('--num_heads', type=int, default=4, help='Number of attention heads.')
parser.add_argument('--num_layers', type=int, default=4, help='Number of model layers.')


def add_diffusion_args(parser):
parser.add_argument('--noise_schedule', type=str, default='cosine', help='Noise schedule: "cosine", "linear", etc.')
parser.add_argument('--diffusion_steps', type=int, default=4, help='Number of diffusion steps.')
parser.add_argument('--diffusion_steps', type=int, default=8, help='Number of diffusion steps.')
parser.add_argument('--sigma_small', action='store_true', help='Use small sigma values.')


def add_train_args(parser):
parser.add_argument('--epoch', type=int, default=300, help='Number of training epochs.')
parser.add_argument('--epoch', type=int, default=500, help='Number of training epochs.')
parser.add_argument('--lr', type=float, default=0.00005, help='Learning rate.')
parser.add_argument('--lr_anneal_steps', type=int, default=0, help='Annealing steps.')
parser.add_argument('--weight_decay', type=float, default=0.00, help='Weight decay.')
parser.add_argument('--weight_decay', type=float, default=0.001, help='Weight decay.')
parser.add_argument('--batch_size', type=int, default=1024, help='Batch size.')
parser.add_argument('--cond_mask_prob', type=float, default=0, help='Conditioning mask probability.')
parser.add_argument('--cond_mask_prob', type=float, default=0.15, help='Conditioning mask probability.')
parser.add_argument('--workers', type=int, default=4, help='Data loader workers.')
parser.add_argument('--ema', default=False, type=bool, help='Use Exponential Moving Average (EMA) for model parameters.')
parser.add_argument('--ema', default=False, type=bool,
help='Use Exponential Moving Average (EMA) for model parameters.')
parser.add_argument('--lambda_vel', type=float, default=1.0, help='Weight factor for the velocity loss term.')
parser.add_argument('--use_loss_vel', action='store_true', default=True, help='Enable velocity loss term.')
parser.add_argument('--use_loss_accel', action='store_true', default=False, help='Enable acceleration loss term.')
parser.add_argument('--lambda_accel', type=float, default=1.0, help='Weight factor for the acceleration loss term.')
parser.add_argument('--guidance_scale', type=float, default=2.0,
help='Classifier-free guidance scale for inference.')
parser.add_argument('--load_num', type=int, default=-1, help='Number of models to load.')
parser.add_argument('--use_amp', action='store_true', default=False, help='Use mixed precision training (AMP).')
parser.add_argument('--eval_freq', type=int, default=1, help='Frequency of evaluation during training.')


def config_parse(args):
@@ -49,14 +59,17 @@ def config_parse(args):
config.trainer.lr_anneal_steps = args.lr_anneal_steps
config.trainer.weight_decay = args.weight_decay
config.trainer.batch_size = args.batch_size
config.trainer.ema = True #if args.ema else config.trainer.ema
config.trainer.ema = True #if args.ema else config.trainer.ema
config.trainer.cond_mask_prob = args.cond_mask_prob
config.trainer.workers = args.workers
config.trainer.save_freq = int(config.trainer.epoch // 5)
config.trainer.lambda_vel = args.lambda_vel
config.trainer.use_loss_vel = args.use_loss_vel
config.trainer.use_loss_accel = args.use_loss_accel
config.trainer.lambda_accel = args.lambda_accel
config.trainer.guidance_scale = args.guidance_scale
config.trainer.load_num = args.load_num


# Save directory
data_prefix = args.data.split('/')[-1].split('.')[0]
config.save = f'{args.save}/{args.name}_{data_prefix}' if 'debug' not in config.name else f'{args.save}/{args.name}'
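A hedged usage sketch of the updated option helpers; it assumes the module is importable as `fluent_pose_synthesis.config.option` and that the helpers define only optional arguments:

```python
# Illustrative: build a parser from the option helpers and inspect the new training flags.
import argparse
from fluent_pose_synthesis.config.option import add_model_args, add_diffusion_args, add_train_args

parser = argparse.ArgumentParser(description="fluent_pose_synthesis training options")
add_model_args(parser)
add_diffusion_args(parser)
add_train_args(parser)

# Enable the acceleration loss, raise the guidance scale, and switch on AMP from the CLI.
args = parser.parse_args(["--use_loss_accel", "--guidance_scale", "3.0", "--use_amp"])
print(args.use_loss_accel, args.lambda_accel, args.guidance_scale, args.use_amp)
```

Note that `--use_loss_vel` is declared with `action='store_true'` and `default=True`, so it cannot be switched off from the command line; only the config file can disable it.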
151 changes: 53 additions & 98 deletions fluent_pose_synthesis/core/models.py
@@ -9,22 +9,19 @@ class OutputProcessMLP(nn.Module):
"""
Output process for the Sign Language Pose Diffusion model.
"""
def __init__(self, input_feats, latent_dim, njoints, nfeats, hidden_dim=512): # add hidden_dim as parameter

def __init__(self, input_feats, latent_dim, njoints, nfeats, hidden_dim=512): # add hidden_dim as parameter
super().__init__()
self.input_feats = input_feats
self.latent_dim = latent_dim
self.njoints = njoints
self.nfeats = nfeats
self.hidden_dim = hidden_dim # store hidden dimension
self.hidden_dim = hidden_dim # store hidden dimension

# MLP layers
self.mlp = nn.Sequential(
nn.Linear(self.latent_dim, self.hidden_dim),
nn.SiLU(),
nn.Linear(self.hidden_dim, self.hidden_dim // 2),
nn.SiLU(),
nn.Linear(self.hidden_dim // 2, self.input_feats)
)
self.mlp = nn.Sequential(nn.Linear(self.latent_dim, self.hidden_dim), nn.SiLU(),
nn.Linear(self.hidden_dim, self.hidden_dim // 2), nn.SiLU(),
nn.Linear(self.hidden_dim // 2, self.input_feats))

def forward(self, output):
nframes, bs, d = output.shape
@@ -39,25 +36,11 @@ class SignLanguagePoseDiffusion(nn.Module):
Sign Language Pose Diffusion model.
"""

def __init__(
self,
input_feats: int,
chunk_len: int,
keypoints: int,
dims: int,
latent_dim: int = 256,
ff_size: int = 1024,
num_layers: int = 8,
num_heads: int = 4,
dropout: float = 0.2,
ablation: Optional[str] = None,
activation: str = "gelu",
legacy: bool = False,
arch: str = "trans_enc",
cond_mask_prob: float = 0,
device: Optional[torch.device] = None,
batch_first: bool = True
):
def __init__(self, input_feats: int, chunk_len: int, keypoints: int, dims: int, latent_dim: int = 256,
ff_size: int = 1024, num_layers: int = 8, num_heads: int = 4, dropout: float = 0.2,
ablation: Optional[str] = None, activation: str = "gelu", legacy: bool = False,
arch: str = "trans_enc", cond_mask_prob: float = 0, device: Optional[torch.device] = None,
batch_first: bool = True):
"""
Args:
input_feats (int): Number of input features (keypoints * dimensions).
@@ -105,37 +88,33 @@ def __init__(
# Define sequence encoder based on chosen architecture
if self.arch == "trans_enc":
print(f"Initializing Transformer Encoder (batch_first={self.batch_first})")
encoder_layer = nn.TransformerEncoderLayer(
d_model=latent_dim, nhead=num_heads, dim_feedforward=ff_size,
dropout=dropout, activation=activation, batch_first=self.batch_first
)
encoder_layer = nn.TransformerEncoderLayer(d_model=latent_dim, nhead=num_heads, dim_feedforward=ff_size,
dropout=dropout, activation=activation,
batch_first=self.batch_first)
self.sequence_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
elif self.arch == "trans_dec":
print(f"Initializing Transformer Decoder (batch_first={self.batch_first})")
decoder_layer = nn.TransformerDecoderLayer(
d_model=latent_dim, nhead=num_heads, dim_feedforward=ff_size,
dropout=dropout, activation=activation, batch_first=self.batch_first
)
decoder_layer = nn.TransformerDecoderLayer(d_model=latent_dim, nhead=num_heads, dim_feedforward=ff_size,
dropout=dropout, activation=activation,
batch_first=self.batch_first)
self.sequence_encoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
elif self.arch == "gru":
print("Initializing GRU Encoder (batch_first=True)")
self.sequence_encoder = nn.GRU(
latent_dim, latent_dim, num_layers=num_layers, batch_first=True
)
print("Initializing GRU Encoder (batch_first=True)")
self.sequence_encoder = nn.GRU(latent_dim, latent_dim, num_layers=num_layers, batch_first=True)
else:
raise ValueError("Please choose correct architecture [trans_enc, trans_dec, gru]")

# Pose projection: projects latent representation back to pose space.
# The OutputProcess returns (B, keypoints, dims, T); apply a post_transform to get (B, T, keypoints, dims)
self.pose_projection = OutputProcessMLP(input_feats, latent_dim, keypoints, dims, hidden_dim=512)
self.pose_projection = OutputProcessMLP(input_feats, latent_dim, keypoints, dims, hidden_dim=1024)
self.to(self.device)

def forward(
self,
fluent_clip: torch.Tensor, # (B, K, D, T_chunk)
disfluent_seq: torch.Tensor, # (B, K, D, T_disfl)
t: torch.Tensor, # (B,)
previous_output: Optional[torch.Tensor] = None # (B, K, D, T_hist)
self,
fluent_clip: torch.Tensor, # (B, K, D, T_chunk)
disfluent_seq: torch.Tensor, # (B, K, D, T_disfl)
t: torch.Tensor, # (B,)
previous_output: Optional[torch.Tensor] = None # (B, K, D, T_hist)
) -> torch.Tensor:

# # --- DEBUG: Print Initial Input Shapes ---
@@ -158,35 +137,35 @@ def forward(
T_chunk = fluent_clip.shape[-1]

# 1. Embed Timestep
_t_emb_raw = self.embed_timestep(t) # Expected (B, D)
_t_emb_raw = self.embed_timestep(t) # Expected (B, D)
# print(f"[DEBUG FWD 1a] Raw t_emb shape: {_t_emb_raw.shape}")
t_emb = _t_emb_raw.permute(1, 0, 2)
t_emb = _t_emb_raw.permute(1, 0, 2).contiguous()
# print(f"[DEBUG FWD 1b] Final t_emb shape: {t_emb.shape}")

# 2. Embed Disfluent Sequence (Condition)
_disfluent_emb_raw = self.disfluent_encoder(disfluent_seq) # Expected (T_disfl, B, D)
_disfluent_emb_raw = self.disfluent_encoder(disfluent_seq) # Expected (T_disfl, B, D)
# print(f"[DEBUG FWD 2a] Raw disfluent_emb shape: {_disfluent_emb_raw.shape}")
disfluent_emb = _disfluent_emb_raw.permute(1, 0, 2) # Expected (B, T_disfl, D)
disfluent_emb = _disfluent_emb_raw.permute(1, 0, 2).contiguous() # Expected (B, T_disfl, D)
# print(f"[DEBUG FWD 2b] Final disfluent_emb shape: {disfluent_emb.shape}")

# 3. Embed Previous Output (History), if available
embeddings_to_concat = [t_emb, disfluent_emb]
# print("[DEBUG FWD 3a] Processing previous_output...")
if previous_output is not None and previous_output.shape[-1] > 0:
# print(f"[DEBUG FWD 3b] History Input shape: {previous_output.shape}")
_prev_out_emb_raw = self.fluent_encoder(previous_output) # Expected (T_hist, B, D)
_prev_out_emb_raw = self.fluent_encoder(previous_output) # Expected (T_hist, B, D)
# print(f"[DEBUG FWD 3c] Raw prev_out_emb shape: {_prev_out_emb_raw.shape}")
prev_out_emb = _prev_out_emb_raw.permute(1, 0, 2) # Expected (B, T_hist, D)
prev_out_emb = _prev_out_emb_raw.permute(1, 0, 2).contiguous() # Expected (B, T_hist, D)
# print(f"[DEBUG FWD 3d] Final prev_out_emb shape: {prev_out_emb.shape}")
embeddings_to_concat.append(prev_out_emb)
else:
# print("[DEBUG FWD 3b] No previous_output provided or it's empty.")
pass

# 4. Embed Current Fluent Clip (Noisy Target 'x')
_fluent_emb_raw = self.fluent_encoder(fluent_clip) # Expected (T_chunk, B, D)
_fluent_emb_raw = self.fluent_encoder(fluent_clip) # Expected (T_chunk, B, D)
# print(f"[DEBUG FWD 4a] Raw fluent_emb shape: {_fluent_emb_raw.shape}")
fluent_emb = _fluent_emb_raw.permute(1, 0, 2) # Expected (B, T_chunk, D)
fluent_emb = _fluent_emb_raw.permute(1, 0, 2).contiguous() # Expected (B, T_chunk, D)
# print(f"[DEBUG FWD 4b] Final fluent_emb shape: {fluent_emb.shape}")
embeddings_to_concat.append(fluent_emb)

@@ -198,33 +177,33 @@ def forward(
# print(f"[DEBUG FWD 6a] xseq shape before PositionalEncoding: {xseq.shape}")
# Adapt based on PositionalEncoding expectation (T, B, D) vs batch_first
if self.batch_first:
xseq_permuted = xseq.permute(1, 0, 2) # (T_total, B, D)
xseq_permuted = xseq.permute(1, 0, 2).contiguous() # (T_total, B, D)
# print(f"[DEBUG FWD 6b] xseq permuted for PosEnc: {xseq_permuted.shape}")
xseq_encoded = self.sequence_pos_encoder(xseq_permuted)
# print(f"[DEBUG FWD 6c] xseq after PosEnc: {xseq_encoded.shape}")
xseq = xseq_encoded.permute(1, 0, 2) # Back to (B, T_total, D)
xseq = xseq_encoded.permute(1, 0, 2) # Back to (B, T_total, D)
# print(f"[DEBUG FWD 6d] xseq permuted back: {xseq.shape}")
else:
# If not batch_first, assume xseq should be (T, B, D) already
# Need to adjust concatenation and permutations above if batch_first=False
xseq = xseq.permute(1, 0, 2) # Assume needs (T, B, D)
# If not batch_first, assume xseq should be (T, B, D) already
# Need to adjust concatenation and permutations above if batch_first=False
xseq = xseq.permute(1, 0, 2) # Assume needs (T, B, D)
# print(f"[DEBUG FWD 6b] xseq permuted for PosEnc (batch_first=False): {xseq.shape}")
xseq = self.sequence_pos_encoder(xseq)
# print(f"[DEBUG FWD 6c] xseq after PosEnc (batch_first=False): {xseq.shape}")
# Keep as (T, B, D) if encoder needs it
xseq = self.sequence_pos_encoder(xseq)
# print(f"[DEBUG FWD 6c] xseq after PosEnc (batch_first=False): {xseq.shape}")
# Keep as (T, B, D) if encoder needs it

# 7. Process through sequence encoder
# print(f"[DEBUG FWD 7a] Input to sequence_encoder ({self.arch}) shape: {xseq.shape}")
if self.arch == "trans_enc":
x_encoded = self.sequence_encoder(xseq)
x_encoded = self.sequence_encoder(xseq)
elif self.arch == "gru":
x_encoded, _ = self.sequence_encoder(xseq)
x_encoded, _ = self.sequence_encoder(xseq)
elif self.arch == "trans_dec":
memory = xseq
tgt = xseq
x_encoded = self.sequence_encoder(tgt=tgt, memory=memory)
memory = xseq
tgt = xseq
x_encoded = self.sequence_encoder(tgt=tgt, memory=memory)
else:
raise ValueError("Unsupported architecture")
raise ValueError("Unsupported architecture")
# print(f"[DEBUG FWD 7b] Output from sequence_encoder shape: {x_encoded.shape}")

# 8. Extract the output corresponding to the target fluent_clip
@@ -249,30 +228,11 @@ def forward(

return output
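The forward pass concatenates, along the time axis, one timestep-embedding frame, the disfluent-condition embedding, the optional history embedding, and the noisy target chunk, then position-encodes the result, runs it through the sequence encoder, and projects only the frames aligned with the target chunk back to pose space. A shape-only sketch of that assembly follows; the sizes and the final slice are illustrative assumptions about step 8, not the module's actual code.

```python
# Illustrative shape walk-through of the concatenation order used in forward().
import torch

B, D_latent = 2, 256
T_disfl, T_hist, T_chunk = 30, 5, 40

t_emb = torch.randn(B, 1, D_latent)              # single frame from the timestep embedding
disfluent_emb = torch.randn(B, T_disfl, D_latent)
prev_out_emb = torch.randn(B, T_hist, D_latent)  # only present when history is provided
fluent_emb = torch.randn(B, T_chunk, D_latent)   # the noisy target chunk

xseq = torch.cat([t_emb, disfluent_emb, prev_out_emb, fluent_emb], dim=1)
print(xseq.shape)  # torch.Size([2, 76, 256]): 1 + 30 + 5 + 40 frames

# Assumed extraction: keep only the positions aligned with the target chunk.
x_out = xseq[:, -T_chunk:, :]
print(x_out.shape)  # torch.Size([2, 40, 256])
```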


# def mingyi_forward(
# self, fluent_clip: torch.Tensor, disfluent_seq: torch.Tensor, t: torch.Tensor
# ) -> torch.Tensor:

# batch_size, keypoints, dims, time = fluent_clip.shape

# t_emb = self.embed_timestep(t)
# disfluent_emb = self.disfluent_encoder(disfluent_seq)
# fluent_emb = self.fluent_encoder(fluent_clip)

# xseq = torch.cat((t_emb, disfluent_emb, fluent_emb), axis=0)
# xseq = self.sequence_pos_encoder(xseq)

# x_out = self.sequence_encoder(xseq)[:time]
# output = self.pose_projection(x_out)

# return output

def interface(
self,
fluent_clip: torch.Tensor, # (B, K, D, T_chunk)
t: torch.Tensor, # (B,)
y: dict[str, torch.Tensor] # Conditions dict
self,
fluent_clip: torch.Tensor, # (B, K, D, T_chunk)
t: torch.Tensor, # (B,)
y: dict[str, torch.Tensor] # Conditions dict
) -> torch.Tensor:
"""
Interface for Classifier-Free Guidance (CFG). Handles previous_output.
@@ -283,13 +243,8 @@
previous_output = y.get("previous_output", None)

# Apply CFG: randomly drop the condition with probability cond_mask_prob
keep_batch_idx = torch.rand(batch_size, device=disfluent_seq.device) < (1-self.cond_mask_prob)
keep_batch_idx = torch.rand(batch_size, device=disfluent_seq.device) < (1 - self.cond_mask_prob)
disfluent_seq = disfluent_seq * keep_batch_idx.view((batch_size, 1, 1, 1))

# Call the forward function
return self.forward(
fluent_clip=fluent_clip,
disfluent_seq=disfluent_seq,
t=t,
previous_output=previous_output
)
return self.forward(fluent_clip=fluent_clip, disfluent_seq=disfluent_seq, t=t, previous_output=previous_output)
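At training time, `interface()` realizes classifier-free guidance only by zeroing the disfluent condition with probability `cond_mask_prob`; the `guidance_scale` added to the config is meant for inference. A hedged sketch of the standard CFG combination, not code from this PR:

```python
# Illustrative: classifier-free guidance at sampling time.
# eps_cond / eps_uncond would come from two model calls, one with the disfluent
# condition kept and one with it zeroed out.
import torch

guidance_scale = 2.0
eps_cond = torch.randn(2, 178, 3, 40)    # (B, K, D, T_chunk); the sizes here are made up
eps_uncond = torch.randn(2, 178, 3, 40)

# Standard CFG: move the prediction away from the unconditional estimate.
eps_guided = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
print(eps_guided.shape)
```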