@@ -9,22 +9,19 @@ class OutputProcessMLP(nn.Module):
9
9
"""
10
10
Output process for the Sign Language Pose Diffusion model.
11
11
"""
12
- def __init__ (self , input_feats , latent_dim , njoints , nfeats , hidden_dim = 512 ): # add hidden_dim as parameter
12
+
13
+ def __init__ (self , input_feats , latent_dim , njoints , nfeats , hidden_dim = 512 ): # add hidden_dim as parameter
13
14
super ().__init__ ()
14
15
self .input_feats = input_feats
15
16
self .latent_dim = latent_dim
16
17
self .njoints = njoints
17
18
self .nfeats = nfeats
18
- self .hidden_dim = hidden_dim # store hidden dimension
19
+ self .hidden_dim = hidden_dim # store hidden dimension
19
20
20
21
# MLP layers
21
- self .mlp = nn .Sequential (
22
- nn .Linear (self .latent_dim , self .hidden_dim ),
23
- nn .SiLU (),
24
- nn .Linear (self .hidden_dim , self .hidden_dim // 2 ),
25
- nn .SiLU (),
26
- nn .Linear (self .hidden_dim // 2 , self .input_feats )
27
- )
22
+ self .mlp = nn .Sequential (nn .Linear (self .latent_dim , self .hidden_dim ), nn .SiLU (),
23
+ nn .Linear (self .hidden_dim , self .hidden_dim // 2 ), nn .SiLU (),
24
+ nn .Linear (self .hidden_dim // 2 , self .input_feats ))
28
25
29
26
def forward (self , output ):
30
27
nframes , bs , d = output .shape
@@ -39,25 +36,11 @@ class SignLanguagePoseDiffusion(nn.Module):
39
36
Sign Language Pose Diffusion model.
40
37
"""
41
38
42
- def __init__ (
43
- self ,
44
- input_feats : int ,
45
- chunk_len : int ,
46
- keypoints : int ,
47
- dims : int ,
48
- latent_dim : int = 256 ,
49
- ff_size : int = 1024 ,
50
- num_layers : int = 8 ,
51
- num_heads : int = 4 ,
52
- dropout : float = 0.2 ,
53
- ablation : Optional [str ] = None ,
54
- activation : str = "gelu" ,
55
- legacy : bool = False ,
56
- arch : str = "trans_enc" ,
57
- cond_mask_prob : float = 0 ,
58
- device : Optional [torch .device ] = None ,
59
- batch_first : bool = True
60
- ):
39
+ def __init__ (self , input_feats : int , chunk_len : int , keypoints : int , dims : int , latent_dim : int = 256 ,
40
+ ff_size : int = 1024 , num_layers : int = 8 , num_heads : int = 4 , dropout : float = 0.2 ,
41
+ ablation : Optional [str ] = None , activation : str = "gelu" , legacy : bool = False ,
42
+ arch : str = "trans_enc" , cond_mask_prob : float = 0 , device : Optional [torch .device ] = None ,
43
+ batch_first : bool = True ):
61
44
"""
62
45
Args:
63
46
input_feats (int): Number of input features (keypoints * dimensions).
@@ -105,37 +88,33 @@ def __init__(
105
88
# Define sequence encoder based on chosen architecture
106
89
if self .arch == "trans_enc" :
107
90
print (f"Initializing Transformer Encoder (batch_first={ self .batch_first } )" )
108
- encoder_layer = nn .TransformerEncoderLayer (
109
- d_model = latent_dim , nhead = num_heads , dim_feedforward = ff_size ,
110
- dropout = dropout , activation = activation , batch_first = self .batch_first
111
- )
91
+ encoder_layer = nn .TransformerEncoderLayer (d_model = latent_dim , nhead = num_heads , dim_feedforward = ff_size ,
92
+ dropout = dropout , activation = activation ,
93
+ batch_first = self .batch_first )
112
94
self .sequence_encoder = nn .TransformerEncoder (encoder_layer , num_layers = num_layers )
113
95
elif self .arch == "trans_dec" :
114
96
print (f"Initializing Transformer Decoder (batch_first={ self .batch_first } )" )
115
- decoder_layer = nn .TransformerDecoderLayer (
116
- d_model = latent_dim , nhead = num_heads , dim_feedforward = ff_size ,
117
- dropout = dropout , activation = activation , batch_first = self .batch_first
118
- )
97
+ decoder_layer = nn .TransformerDecoderLayer (d_model = latent_dim , nhead = num_heads , dim_feedforward = ff_size ,
98
+ dropout = dropout , activation = activation ,
99
+ batch_first = self .batch_first )
119
100
self .sequence_encoder = nn .TransformerDecoder (decoder_layer , num_layers = num_layers )
120
101
elif self .arch == "gru" :
121
- print ("Initializing GRU Encoder (batch_first=True)" )
122
- self .sequence_encoder = nn .GRU (
123
- latent_dim , latent_dim , num_layers = num_layers , batch_first = True
124
- )
102
+ print ("Initializing GRU Encoder (batch_first=True)" )
103
+ self .sequence_encoder = nn .GRU (latent_dim , latent_dim , num_layers = num_layers , batch_first = True )
125
104
else :
126
105
raise ValueError ("Please choose correct architecture [trans_enc, trans_dec, gru]" )
127
106
128
107
# Pose projection: projects latent representation back to pose space.
129
108
# The OutputProcess returns (B, keypoints, dims, T); apply a post_transform to get (B, T, keypoints, dims)
130
- self .pose_projection = OutputProcessMLP (input_feats , latent_dim , keypoints , dims , hidden_dim = 512 )
109
+ self .pose_projection = OutputProcessMLP (input_feats , latent_dim , keypoints , dims , hidden_dim = 1024 )
131
110
self .to (self .device )
132
111
133
112
def forward (
134
- self ,
135
- fluent_clip : torch .Tensor , # (B, K, D, T_chunk)
136
- disfluent_seq : torch .Tensor , # (B, K, D, T_disfl)
137
- t : torch .Tensor , # (B,)
138
- previous_output : Optional [torch .Tensor ] = None # (B, K, D, T_hist)
113
+ self ,
114
+ fluent_clip : torch .Tensor , # (B, K, D, T_chunk)
115
+ disfluent_seq : torch .Tensor , # (B, K, D, T_disfl)
116
+ t : torch .Tensor , # (B,)
117
+ previous_output : Optional [torch .Tensor ] = None # (B, K, D, T_hist)
139
118
) -> torch .Tensor :
140
119
141
120
# # --- DEBUG: Print Initial Input Shapes ---
@@ -158,35 +137,35 @@ def forward(
158
137
T_chunk = fluent_clip .shape [- 1 ]
159
138
160
139
# 1. Embed Timestep
161
- _t_emb_raw = self .embed_timestep (t ) # Expected (B, D)
140
+ _t_emb_raw = self .embed_timestep (t ) # Expected (B, D)
162
141
# print(f"[DEBUG FWD 1a] Raw t_emb shape: {_t_emb_raw.shape}")
163
- t_emb = _t_emb_raw .permute (1 , 0 , 2 )
142
+ t_emb = _t_emb_raw .permute (1 , 0 , 2 ). contiguous ()
164
143
# print(f"[DEBUG FWD 1b] Final t_emb shape: {t_emb.shape}")
165
144
166
145
# 2. Embed Disfluent Sequence (Condition)
167
- _disfluent_emb_raw = self .disfluent_encoder (disfluent_seq ) # Expected (T_disfl, B, D)
146
+ _disfluent_emb_raw = self .disfluent_encoder (disfluent_seq ) # Expected (T_disfl, B, D)
168
147
# print(f"[DEBUG FWD 2a] Raw disfluent_emb shape: {_disfluent_emb_raw.shape}")
169
- disfluent_emb = _disfluent_emb_raw .permute (1 , 0 , 2 ) # Expected (B, T_disfl, D)
148
+ disfluent_emb = _disfluent_emb_raw .permute (1 , 0 , 2 ). contiguous () # Expected (B, T_disfl, D)
170
149
# print(f"[DEBUG FWD 2b] Final disfluent_emb shape: {disfluent_emb.shape}")
171
150
172
151
# 3. Embed Previous Output (History), if available
173
152
embeddings_to_concat = [t_emb , disfluent_emb ]
174
153
# print("[DEBUG FWD 3a] Processing previous_output...")
175
154
if previous_output is not None and previous_output .shape [- 1 ] > 0 :
176
155
# print(f"[DEBUG FWD 3b] History Input shape: {previous_output.shape}")
177
- _prev_out_emb_raw = self .fluent_encoder (previous_output ) # Expected (T_hist, B, D)
156
+ _prev_out_emb_raw = self .fluent_encoder (previous_output ) # Expected (T_hist, B, D)
178
157
# print(f"[DEBUG FWD 3c] Raw prev_out_emb shape: {_prev_out_emb_raw.shape}")
179
- prev_out_emb = _prev_out_emb_raw .permute (1 , 0 , 2 ) # Expected (B, T_hist, D)
158
+ prev_out_emb = _prev_out_emb_raw .permute (1 , 0 , 2 ). contiguous () # Expected (B, T_hist, D)
180
159
# print(f"[DEBUG FWD 3d] Final prev_out_emb shape: {prev_out_emb.shape}")
181
160
embeddings_to_concat .append (prev_out_emb )
182
161
else :
183
162
# print("[DEBUG FWD 3b] No previous_output provided or it's empty.")
184
163
pass
185
164
186
165
# 4. Embed Current Fluent Clip (Noisy Target 'x')
187
- _fluent_emb_raw = self .fluent_encoder (fluent_clip ) # Expected (T_chunk, B, D)
166
+ _fluent_emb_raw = self .fluent_encoder (fluent_clip ) # Expected (T_chunk, B, D)
188
167
# print(f"[DEBUG FWD 4a] Raw fluent_emb shape: {_fluent_emb_raw.shape}")
189
- fluent_emb = _fluent_emb_raw .permute (1 , 0 , 2 ) # Expected (B, T_chunk, D)
168
+ fluent_emb = _fluent_emb_raw .permute (1 , 0 , 2 ). contiguous () # Expected (B, T_chunk, D)
190
169
# print(f"[DEBUG FWD 4b] Final fluent_emb shape: {fluent_emb.shape}")
191
170
embeddings_to_concat .append (fluent_emb )
192
171
@@ -198,33 +177,33 @@ def forward(
198
177
# print(f"[DEBUG FWD 6a] xseq shape before PositionalEncoding: {xseq.shape}")
199
178
# Adapt based on PositionalEncoding expectation (T, B, D) vs batch_first
200
179
if self .batch_first :
201
- xseq_permuted = xseq .permute (1 , 0 , 2 ) # (T_total, B, D)
180
+ xseq_permuted = xseq .permute (1 , 0 , 2 ). contiguous () # (T_total, B, D)
202
181
# print(f"[DEBUG FWD 6b] xseq permuted for PosEnc: {xseq_permuted.shape}")
203
182
xseq_encoded = self .sequence_pos_encoder (xseq_permuted )
204
183
# print(f"[DEBUG FWD 6c] xseq after PosEnc: {xseq_encoded.shape}")
205
- xseq = xseq_encoded .permute (1 , 0 , 2 ) # Back to (B, T_total, D)
184
+ xseq = xseq_encoded .permute (1 , 0 , 2 ) # Back to (B, T_total, D)
206
185
# print(f"[DEBUG FWD 6d] xseq permuted back: {xseq.shape}")
207
186
else :
208
- # If not batch_first, assume xseq should be (T, B, D) already
209
- # Need to adjust concatenation and permutations above if batch_first=False
210
- xseq = xseq .permute (1 , 0 , 2 ) # Assume needs (T, B, D)
187
+ # If not batch_first, assume xseq should be (T, B, D) already
188
+ # Need to adjust concatenation and permutations above if batch_first=False
189
+ xseq = xseq .permute (1 , 0 , 2 ) # Assume needs (T, B, D)
211
190
# print(f"[DEBUG FWD 6b] xseq permuted for PosEnc (batch_first=False): {xseq.shape}")
212
- xseq = self .sequence_pos_encoder (xseq )
213
- # print(f"[DEBUG FWD 6c] xseq after PosEnc (batch_first=False): {xseq.shape}")
214
- # Keep as (T, B, D) if encoder needs it
191
+ xseq = self .sequence_pos_encoder (xseq )
192
+ # print(f"[DEBUG FWD 6c] xseq after PosEnc (batch_first=False): {xseq.shape}")
193
+ # Keep as (T, B, D) if encoder needs it
215
194
216
195
# 7. Process through sequence encoder
217
196
# print(f"[DEBUG FWD 7a] Input to sequence_encoder ({self.arch}) shape: {xseq.shape}")
218
197
if self .arch == "trans_enc" :
219
- x_encoded = self .sequence_encoder (xseq )
198
+ x_encoded = self .sequence_encoder (xseq )
220
199
elif self .arch == "gru" :
221
- x_encoded , _ = self .sequence_encoder (xseq )
200
+ x_encoded , _ = self .sequence_encoder (xseq )
222
201
elif self .arch == "trans_dec" :
223
- memory = xseq
224
- tgt = xseq
225
- x_encoded = self .sequence_encoder (tgt = tgt , memory = memory )
202
+ memory = xseq
203
+ tgt = xseq
204
+ x_encoded = self .sequence_encoder (tgt = tgt , memory = memory )
226
205
else :
227
- raise ValueError ("Unsupported architecture" )
206
+ raise ValueError ("Unsupported architecture" )
228
207
# print(f"[DEBUG FWD 7b] Output from sequence_encoder shape: {x_encoded.shape}")
229
208
230
209
# 8. Extract the output corresponding to the target fluent_clip
@@ -249,30 +228,11 @@ def forward(
249
228
250
229
return output
251
230
252
-
253
- # def mingyi_forward(
254
- # self, fluent_clip: torch.Tensor, disfluent_seq: torch.Tensor, t: torch.Tensor
255
- # ) -> torch.Tensor:
256
-
257
- # batch_size, keypoints, dims, time = fluent_clip.shape
258
-
259
- # t_emb = self.embed_timestep(t)
260
- # disfluent_emb = self.disfluent_encoder(disfluent_seq)
261
- # fluent_emb = self.fluent_encoder(fluent_clip)
262
-
263
- # xseq = torch.cat((t_emb, disfluent_emb, fluent_emb), axis=0)
264
- # xseq = self.sequence_pos_encoder(xseq)
265
-
266
- # x_out = self.sequence_encoder(xseq)[:time]
267
- # output = self.pose_projection(x_out)
268
-
269
- # return output
270
-
271
231
def interface (
272
- self ,
273
- fluent_clip : torch .Tensor , # (B, K, D, T_chunk)
274
- t : torch .Tensor , # (B,)
275
- y : dict [str , torch .Tensor ] # Conditions dict
232
+ self ,
233
+ fluent_clip : torch .Tensor , # (B, K, D, T_chunk)
234
+ t : torch .Tensor , # (B,)
235
+ y : dict [str , torch .Tensor ] # Conditions dict
276
236
) -> torch .Tensor :
277
237
"""
278
238
Interface for Classifier-Free Guidance (CFG). Handles previous_output.
@@ -283,13 +243,8 @@ def interface(
283
243
previous_output = y .get ("previous_output" , None )
284
244
285
245
# Apply CFG: randomly drop the condition with probability cond_mask_prob
286
- keep_batch_idx = torch .rand (batch_size , device = disfluent_seq .device ) < (1 - self .cond_mask_prob )
246
+ keep_batch_idx = torch .rand (batch_size , device = disfluent_seq .device ) < (1 - self .cond_mask_prob )
287
247
disfluent_seq = disfluent_seq * keep_batch_idx .view ((batch_size , 1 , 1 , 1 ))
288
248
289
249
# Call the forward function
290
- return self .forward (
291
- fluent_clip = fluent_clip ,
292
- disfluent_seq = disfluent_seq ,
293
- t = t ,
294
- previous_output = previous_output
295
- )
250
+ return self .forward (fluent_clip = fluent_clip , disfluent_seq = disfluent_seq , t = t , previous_output = previous_output )
0 commit comments