@@ -116,6 +116,20 @@ def pad_to_length(t, length, dim = -1, value = 0, right = True):
     padding = (0, remainder) if right else (remainder, 0)
     return pad_at_dim(t, padding, dim = dim, value = value)
 
+def masked_mean(tensor, mask, dim = -1, eps = 1e-5):
+    if not exists(mask):
+        return tensor.mean(dim = dim)
+
+    mask = rearrange(mask, '... -> ... 1')
+    tensor = tensor.masked_fill(~mask, 0.)
+
+    total_el = mask.sum(dim = dim)
+    num = tensor.sum(dim = dim)
+    den = total_el.float().clamp(min = eps)
+    mean = num / den
+    mean = mean.masked_fill(total_el == 0, 0.)
+    return mean
+
 # continuous embed
 
 def ContinuousEmbed(dim_cont):
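Note: a minimal, self-contained sketch of how this masked pooling is expected to behave on a padded batch. The helper is repeated verbatim only so the snippet runs on its own; the shapes and values are illustrative assumptions, not taken from the library.

import torch
from einops import rearrange

def exists(v):
    return v is not None

def masked_mean(tensor, mask, dim = -1, eps = 1e-5):
    # mean over `dim`, counting only positions where mask is True
    if not exists(mask):
        return tensor.mean(dim = dim)

    mask = rearrange(mask, '... -> ... 1')
    tensor = tensor.masked_fill(~mask, 0.)

    total_el = mask.sum(dim = dim)
    num = tensor.sum(dim = dim)
    den = total_el.float().clamp(min = eps)
    mean = num / den
    mean = mean.masked_fill(total_el == 0, 0.)
    return mean

# two text sequences, the second padded after its first token
embeds = torch.randn(2, 3, 4)                      # (batch, seq, dim)
mask = torch.tensor([[True, True, True],
                     [True, False, False]])

pooled = masked_mean(embeds, mask, dim = 1)        # (batch, dim)
assert torch.allclose(pooled[1], embeds[1, 0])     # only the unmasked token contributes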
@@ -1039,9 +1053,14 @@ def __init__(
         self.codebook_size = autoencoder.codebook_size
         self.num_quantizers = autoencoder.num_quantizers
 
-        self.sos_token = nn.Parameter(torch.randn(dim_fine))
         self.eos_token_id = self.codebook_size
 
+        # the fine transformer sos token
+        # as well as a projection of pooled text embeddings to condition it
+        # (todo) - sos token should be moved to the coarse transformer stage
+
+        self.sos_token = nn.Parameter(torch.randn(dim_fine))
+
         # they use axial positional embeddings
 
         assert divisible_by(max_seq_len, self.num_vertices_per_face * self.num_quantizers), f'max_seq_len ({max_seq_len}) must be divisible by (3 x {self.num_quantizers}) = {3 * self.num_quantizers}' # 3 or 4 vertices per face, with D codes per vertex
@@ -1067,7 +1086,11 @@ def __init__(
                 model_types = text_condition_model_types,
                 cond_drop_prob = text_condition_cond_drop_prob
             )
-            cross_attn_dim_context = self.conditioner.dim_latent
+
+            dim_text = self.conditioner.dim_latent
+            cross_attn_dim_context = dim_text
+
+            self.to_sos_text_cond = nn.Linear(dim_text, dim_fine)
 
         # for summarizing the vertices of each face
 
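Note: `dim_text` is factored out here because the conditioner's latent dimension is now used twice, as the cross-attention context size and as the input size of the new sos projection. A quick shape check, with made-up dimensions standing in for the real config values:

import torch
from torch import nn

dim_text = 512   # conditioner latent dim (illustrative)
dim_fine = 256   # fine transformer dim (illustrative)

to_sos_text_cond = nn.Linear(dim_text, dim_fine)

pooled_text_embed = torch.randn(2, dim_text)    # one pooled text embedding per batch element
sos_cond = to_sos_text_cond(pooled_text_embed)  # (2, dim_fine), ready to be added to the sos token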
@@ -1437,10 +1460,24 @@ def forward_on_codes(
 
         attended_face_codes = self.maybe_project_coarse_to_fine(attended_face_codes)
 
-        # auto prepend sos token
+        # repeat sos token across batch
 
         sos = repeat(self.sos_token, 'd -> b d', b = batch)
 
+        # condition sos token if needed
+
+        if self.condition_on_text:
+            pooled_text_embed = masked_mean(
+                maybe_dropped_text_embeds.embed,
+                maybe_dropped_text_embeds.mask,
+                dim = 1
+            )
+
+            sos_cond = self.to_sos_text_cond(pooled_text_embed)
+            sos = sos + sos_cond
+
+        # auto prepend sos token
+
         attended_face_codes_with_sos, _ = pack([sos, attended_face_codes], 'b * d')
 
         grouped_codes = pad_to_length(grouped_codes, attended_face_codes_with_sos.shape[-2], dim = 1)
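Note: putting the pieces together, a condensed sketch of the conditioning path added in this commit — pool the (possibly dropped-out) text token embeddings, project them to the fine transformer width, shift the sos token, then prepend it. The sizes and the random stand-in for the pooled text embedding are assumptions for illustration; the real values come from the conditioner and model config.

import torch
from torch import nn
from einops import repeat, pack

batch, num_faces, dim_text, dim_fine = 2, 8, 512, 256   # illustrative sizes

sos_token = nn.Parameter(torch.randn(dim_fine))
to_sos_text_cond = nn.Linear(dim_text, dim_fine)

attended_face_codes = torch.randn(batch, num_faces, dim_fine)
pooled_text_embed = torch.randn(batch, dim_text)    # stands in for masked_mean over the text embeddings

# repeat the sos token across the batch, then shift it by the projected text summary
sos = repeat(sos_token, 'd -> b d', b = batch)
sos = sos + to_sos_text_cond(pooled_text_embed)

# prepend the conditioned sos token in front of the per-face codes
face_codes_with_sos, _ = pack([sos, attended_face_codes], 'b * d')
assert face_codes_with_sos.shape == (batch, num_faces + 1, dim_fine)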