
Commit 5ef6cbf

add cross attention based text conditioning for fine transformer too
1 parent 8187f8d commit 5ef6cbf

File tree

meshgpt_pytorch/meshgpt_pytorch.py
meshgpt_pytorch/version.py

2 files changed: +21 -6 lines


meshgpt_pytorch/meshgpt_pytorch.py

Lines changed: 20 additions & 5 deletions
@@ -1139,6 +1139,9 @@ def __init__(
             attn_flash = flash_attn,
             attn_dropout = dropout,
             ff_dropout = dropout,
+            cross_attend = condition_on_text,
+            cross_attn_dim_context = cross_attn_dim_context,
+            cross_attn_num_mem_kv = cross_attn_num_mem_kv,
             **attn_kwargs
         )
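
These three flags turn on cross attention in the decoder constructed here (per the commit message, the fine transformer now gets the same text conditioning path the coarse one already had). With cross_attend enabled, an x-transformers style Decoder accepts a context / context_mask pair at call time. A minimal sketch of that calling convention, with made-up sizes (512, 768) purely for illustration:

    import torch
    from x_transformers import Decoder

    # hypothetical sizes, not taken from this commit
    decoder = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        cross_attend = True,             # inserts a cross attention block after each self attention block
        cross_attn_dim_context = 768     # text embeddings may have a different width than the model
    )

    tokens = torch.randn(2, 16, 512)      # (batch, sequence, model dim)
    text_embed = torch.randn(2, 32, 768)  # (batch, text length, context dim)
    text_mask = torch.ones(2, 32).bool()  # True where the text token is valid

    out = decoder(tokens, context = text_embed, context_mask = text_mask)  # (2, 16, 512)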

@@ -1338,9 +1341,11 @@ def forward_on_codes(
                 cond_drop_prob = cond_drop_prob
             )
 
+            text_embed, text_mask = maybe_dropped_text_embeds
+
             attn_context_kwargs = dict(
-                context = maybe_dropped_text_embeds.embed,
-                context_mask = maybe_dropped_text_embeds.mask
+                context = text_embed,
+                context_mask = text_mask
             )
 
         # take care of codes that may be flattened
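
The conditioner's return is now unpacked once into text_embed and text_mask and reused further down. For context, cond_drop_prob is the usual classifier-free-guidance dropout: with that probability a sample's text conditioning is dropped so the model also learns an unconditional path. A rough, generic sketch of what such conditioning dropout does (illustrative only, not this repo's conditioner):

    import torch

    def drop_text_cond(text_embed, text_mask, cond_drop_prob):
        # text_embed: (batch, text length, dim), text_mask: (batch, text length) bool
        batch = text_embed.shape[0]
        keep = torch.rand(batch, device = text_embed.device) >= cond_drop_prob
        # zero the embeddings and mask away every text token for dropped samples
        text_embed = text_embed * keep[:, None, None]
        text_mask = text_mask & keep[:, None]
        return text_embed, text_mask

    text_embed, text_mask = drop_text_cond(torch.randn(2, 32, 768), torch.ones(2, 32).bool(), cond_drop_prob = 0.25)
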
@@ -1471,8 +1476,8 @@ def forward_on_codes(
 
         if self.condition_on_text:
             pooled_text_embed = masked_mean(
-                maybe_dropped_text_embeds.embed,
-                maybe_dropped_text_embeds.mask,
+                text_embed,
+                text_mask,
                 dim = 1
             )
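
The pooled text embedding is likewise read from the unpacked names. masked_mean averages the text token embeddings over the sequence dimension while ignoring padded positions; a minimal generic sketch of that operation (not the repo's helper verbatim):

    import torch

    def masked_mean(embed, mask, dim = 1, eps = 1e-5):
        # embed: (batch, seq, dim), mask: (batch, seq) with True at valid (non padded) positions
        mask = mask.unsqueeze(-1).float()              # (batch, seq, 1)
        summed = (embed * mask).sum(dim = dim)         # padded tokens contribute zero
        denom = mask.sum(dim = dim).clamp(min = eps)   # number of valid tokens per sample
        return summed / denom

    pooled = masked_mean(torch.randn(2, 32, 768), torch.ones(2, 32).bool(), dim = 1)  # (2, 768)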

@@ -1512,15 +1517,25 @@ def forward_on_codes(
                 ck, cv = map(lambda t: t[:, -1, :, :curr_vertex_pos], (ck, cv))
                 attn_intermediate.cached_kv = (ck, cv)
 
-        one_face = fine_vertex_codes.shape[1] == 1
+        num_faces = fine_vertex_codes.shape[1]
+        one_face = num_faces == 1
 
         fine_vertex_codes = rearrange(fine_vertex_codes, 'b nf n d -> (b nf) n d')
 
         if one_face:
             fine_vertex_codes = fine_vertex_codes[:, :(curr_vertex_pos + 1)]
 
+        fine_attn_context_kwargs = dict()
+
+        if self.condition_on_text:
+            fine_attn_context_kwargs = dict(
+                context = repeat(text_embed, 'b ... -> (b nf) ...', nf = num_faces),
+                context_mask = repeat(text_mask, 'b ... -> (b nf) ...', nf = num_faces)
+            )
+
         attended_vertex_codes, fine_cache = self.fine_decoder(
             fine_vertex_codes,
+            **fine_attn_context_kwargs,
             cache = fine_cache,
             return_hiddens = True
         )
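
This is the heart of the commit: the fine decoder runs on sequences flattened to (batch * num_faces), as the rearrange above shows, so the text context and its mask are tiled once per face before being handed to self.fine_decoder as cross attention inputs. A small standalone sketch of that shape bookkeeping, with illustrative sizes:

    import torch
    from einops import repeat

    batch, num_faces, text_len, dim_context = 2, 4, 32, 768   # illustrative sizes only

    text_embed = torch.randn(batch, text_len, dim_context)
    text_mask = torch.ones(batch, text_len).bool()

    # tile the text conditioning once per face so it matches the (batch * num_faces) fine sequences
    context = repeat(text_embed, 'b ... -> (b nf) ...', nf = num_faces)
    context_mask = repeat(text_mask, 'b ... -> (b nf) ...', nf = num_faces)

    print(context.shape)       # torch.Size([8, 32, 768])
    print(context_mask.shape)  # torch.Size([8, 32])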

meshgpt_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '1.2.2'
+__version__ = '1.2.3'
