@@ -1139,6 +1139,9 @@ def __init__(
1139
1139
attn_flash = flash_attn ,
1140
1140
attn_dropout = dropout ,
1141
1141
ff_dropout = dropout ,
1142
+ cross_attend = condition_on_text ,
1143
+ cross_attn_dim_context = cross_attn_dim_context ,
1144
+ cross_attn_num_mem_kv = cross_attn_num_mem_kv ,
1142
1145
** attn_kwargs
1143
1146
)
1144
1147
@@ -1338,9 +1341,11 @@ def forward_on_codes(
1338
1341
cond_drop_prob = cond_drop_prob
1339
1342
)
1340
1343
1344
+ text_embed , text_mask = maybe_dropped_text_embeds
1345
+
1341
1346
attn_context_kwargs = dict (
1342
- context = maybe_dropped_text_embeds . embed ,
1343
- context_mask = maybe_dropped_text_embeds . mask
1347
+ context = text_embed ,
1348
+ context_mask = text_mask
1344
1349
)
1345
1350
1346
1351
# take care of codes that may be flattened
@@ -1471,8 +1476,8 @@ def forward_on_codes(
1471
1476
1472
1477
if self .condition_on_text :
1473
1478
pooled_text_embed = masked_mean (
1474
- maybe_dropped_text_embeds . embed ,
1475
- maybe_dropped_text_embeds . mask ,
1479
+ text_embed ,
1480
+ text_mask ,
1476
1481
dim = 1
1477
1482
)
1478
1483
@@ -1512,15 +1517,25 @@ def forward_on_codes(
1512
1517
ck , cv = map (lambda t : t [:, - 1 , :, :curr_vertex_pos ], (ck , cv ))
1513
1518
attn_intermediate .cached_kv = (ck , cv )
1514
1519
1515
- one_face = fine_vertex_codes .shape [1 ] == 1
1520
+ num_faces = fine_vertex_codes .shape [1 ]
1521
+ one_face = num_faces == 1
1516
1522
1517
1523
fine_vertex_codes = rearrange (fine_vertex_codes , 'b nf n d -> (b nf) n d' )
1518
1524
1519
1525
if one_face :
1520
1526
fine_vertex_codes = fine_vertex_codes [:, :(curr_vertex_pos + 1 )]
1521
1527
1528
+ fine_attn_context_kwargs = dict ()
1529
+
1530
+ if self .condition_on_text :
1531
+ fine_attn_context_kwargs = dict (
1532
+ context = repeat (text_embed , 'b ... -> (b nf) ...' , nf = num_faces ),
1533
+ context_mask = repeat (text_mask , 'b ... -> (b nf) ...' , nf = num_faces )
1534
+ )
1535
+
1522
1536
attended_vertex_codes , fine_cache = self .fine_decoder (
1523
1537
fine_vertex_codes ,
1538
+ ** fine_attn_context_kwargs ,
1524
1539
cache = fine_cache ,
1525
1540
return_hiddens = True
1526
1541
)
0 commit comments