
Commit f1d392d

deal yet another blow to issue #80 by allowing for multiple sos tokens, pooled before the fine transformer
1 parent 34f2806 commit f1d392d

2 files changed: +16 -6 lines changed


meshgpt_pytorch/meshgpt_pytorch.py

Lines changed: 15 additions & 5 deletions
@@ -1038,10 +1038,11 @@ def __init__(
         fine_attn_dim_head = 32,
         fine_attn_heads = 8,
         pad_id = -1,
+        num_sos_tokens = 1,
         condition_on_text = False,
         text_condition_model_types = ('t5',),
         text_condition_cond_drop_prob = 0.25,
-        quads = False
+        quads = False,
     ):
         super().__init__()
         self.num_vertices_per_face = 3 if not quads else 4
@@ -1060,9 +1061,11 @@ def __init__(

         # the fine transformer sos token
         # as well as a projection of pooled text embeddings to condition it
-        # (todo) - sos token should be moved to the coarse transformer stage

-        self.sos_token = nn.Parameter(torch.randn(dim_fine))
+        assert num_sos_tokens > 0
+
+        self.num_sos_tokens = num_sos_tokens
+        self.sos_token = nn.Parameter(torch.randn(num_sos_tokens, dim))

         # they use axial positional embeddings

@@ -1444,8 +1447,8 @@ def forward_on_codes(
         else:
             # auto prepend sos token

-            sos = repeat(self.sos_token, 'd -> b d', b = batch)
-            face_codes, _ = pack([sos, face_codes], 'b * d')
+            sos = repeat(self.sos_token, 'n d -> b n d', b = batch)
+            face_codes, packed_sos_shape = pack([sos, face_codes], 'b * d')

         # if no kv cache, always call first transformer

@@ -1470,6 +1473,13 @@ def forward_on_codes(

         attended_face_codes = safe_cat((cached_attended_face_codes, attended_face_codes), dim = -2)

+        # if calling without kv cache, pool the sos tokens, if greater than 1 sos token
+
+        if not exists(cache) and self.num_sos_tokens > 1:
+            sos_tokens, attended_face_codes = unpack(attended_face_codes, packed_sos_shape, 'b * d')
+            pooled_sos_token = reduce(sos_tokens, 'b n d -> b 1 d', 'mean')
+            attended_face_codes = torch.cat((pooled_sos_token, attended_face_codes), dim = 1)
+
         # maybe project from coarse to fine dimension for hierarchical transformers

         attended_face_codes = self.maybe_project_coarse_to_fine(attended_face_codes)
meshgpt_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '1.2.4'
+__version__ = '1.2.5'
