@@ -1075,7 +1075,7 @@ def forward(
         return recon_faces, total_loss, loss_breakdown

 @save_load(version = __version__)
-class MeshTransformer(Module,PyTorchModelHubMixin):
+class MeshTransformer(Module, PyTorchModelHubMixin):
     @typecheck
     def __init__(
         self,
@@ -1094,12 +1094,13 @@ def __init__(
         cross_attn_num_mem_kv = 4, # needed for preventing nan when dropping out text condition
         dropout = 0.,
         coarse_pre_gateloop_depth = 2,
+        coarse_adaptive_rmsnorm = False,
         fine_pre_gateloop_depth = 2,
         gateloop_use_heinsen = False,
         fine_attn_depth = 2,
         fine_attn_dim_head = 32,
         fine_attn_heads = 8,
-        fine_cross_attend_text = False,
+        fine_cross_attend_text = False, # additional conditioning - fine transformer cross attention to text tokens
         pad_id = -1,
         num_sos_tokens = None,
         condition_on_text = False,
@@ -1177,6 +1178,8 @@ def __init__(
         # main autoregressive attention network
         # attending to a face token

+        self.coarse_adaptive_rmsnorm = coarse_adaptive_rmsnorm
+
         self.decoder = Decoder(
             dim = dim,
             depth = attn_depth,
@@ -1185,6 +1188,8 @@ def __init__(
             attn_flash = flash_attn,
             attn_dropout = dropout,
             ff_dropout = dropout,
+            use_adaptive_rmsnorm = coarse_adaptive_rmsnorm,
+            dim_condition = dim_text,
             cross_attend = condition_on_text,
             cross_attn_dim_context = cross_attn_dim_context,
             cross_attn_num_mem_kv = cross_attn_num_mem_kv,
@@ -1458,6 +1463,11 @@ def forward_on_codes(
                 context_mask = text_mask
             )

+            if self.coarse_adaptive_rmsnorm:
+                attn_context_kwargs.update(
+                    condition = pooled_text_embed
+                )
+
         # take care of codes that may be flattened

         if codes.ndim > 2:
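Below is a hypothetical usage sketch, not part of the diff, showing how the new coarse_adaptive_rmsnorm flag could be enabled alongside text conditioning. Only condition_on_text, coarse_adaptive_rmsnorm, and fine_cross_attend_text appear in this diff; the remaining constructor arguments (the autoencoder positional argument, dim, max_seq_len, num_discrete_coors) are assumed from the surrounding meshgpt-pytorch code.

# hypothetical usage sketch - arguments not touched by this diff are assumed
from meshgpt_pytorch import MeshAutoencoder, MeshTransformer

autoencoder = MeshAutoencoder(
    num_discrete_coors = 128             # assumed value, not dictated by this diff
)

transformer = MeshTransformer(
    autoencoder,
    dim = 512,
    max_seq_len = 768,
    condition_on_text = True,            # text conditioning must be on for a pooled text embed to exist
    coarse_adaptive_rmsnorm = True,      # new flag: coarse decoder RMSNorm layers conditioned on the pooled text embed
    fine_cross_attend_text = True        # optional extra conditioning: fine transformer cross attends to text tokens
)

When coarse_adaptive_rmsnorm is enabled, the coarse decoder is built with use_adaptive_rmsnorm and dim_condition set, and forward_on_codes passes the pooled text embedding through attn_context_kwargs as the norm condition.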