
Commit b5e8d10 (parent 7e58dd3)

address first set of fine tokens not being conditioned #80

File tree: 2 files changed (+41, -4 lines)


meshgpt_pytorch/meshgpt_pytorch.py

Lines changed: 40 additions & 3 deletions
@@ -116,6 +116,20 @@ def pad_to_length(t, length, dim = -1, value = 0, right = True):
     padding = (0, remainder) if right else (remainder, 0)
     return pad_at_dim(t, padding, dim = dim, value = value)
 
+def masked_mean(tensor, mask, dim = -1, eps = 1e-5):
+    if not exists(mask):
+        return tensor.mean(dim = dim)
+
+    mask = rearrange(mask, '... -> ... 1')
+    tensor = tensor.masked_fill(~mask, 0.)
+
+    total_el = mask.sum(dim = dim)
+    num = tensor.sum(dim = dim)
+    den = total_el.float().clamp(min = eps)
+    mean = num / den
+    mean = mean.masked_fill(total_el == 0, 0.)
+    return mean
+
 # continuous embed
 
 def ContinuousEmbed(dim_cont):
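
For context, a minimal self-contained sketch (shapes and values below are illustrative, not taken from the repo) of how the new masked_mean helper pools a variable-length sequence of text token embeddings into a single vector per batch element, ignoring padded positions:

import torch
from einops import rearrange

def masked_mean(tensor, mask, dim = -1, eps = 1e-5):
    # mirrors the helper added above; `exists(mask)` is replaced by a plain None check
    if mask is None:
        return tensor.mean(dim = dim)

    mask = rearrange(mask, '... -> ... 1')          # broadcast mask over the feature dim
    tensor = tensor.masked_fill(~mask, 0.)          # zero out padded positions

    total_el = mask.sum(dim = dim)                  # number of valid tokens per row
    num = tensor.sum(dim = dim)
    den = total_el.float().clamp(min = eps)         # avoid division by zero
    mean = num / den
    return mean.masked_fill(total_el == 0, 0.)      # fully-masked rows pool to zero

text_embeds = torch.randn(2, 4, 8)                  # (batch, tokens, dim_text)
mask = torch.tensor([[True, True, False, False],
                     [True, True, True,  True]])

pooled = masked_mean(text_embeds, mask, dim = 1)    # (batch, dim_text)
assert torch.allclose(pooled[0], text_embeds[0, :2].mean(dim = 0), atol = 1e-5)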
@@ -1039,9 +1053,14 @@ def __init__(
         self.codebook_size = autoencoder.codebook_size
         self.num_quantizers = autoencoder.num_quantizers
 
-        self.sos_token = nn.Parameter(torch.randn(dim_fine))
         self.eos_token_id = self.codebook_size
 
+        # the fine transformer sos token
+        # as well as a projection of pooled text embeddings to condition it
+        # (todo) - sos token should be moved to the coarse transformer stage
+
+        self.sos_token = nn.Parameter(torch.randn(dim_fine))
+
         # they use axial positional embeddings
 
         assert divisible_by(max_seq_len, self.num_vertices_per_face * self.num_quantizers), f'max_seq_len ({max_seq_len}) must be divisible by (3 x {self.num_quantizers}) = {3 * self.num_quantizers}' # 3 or 4 vertices per face, with D codes per vertex

@@ -1067,7 +1086,11 @@ def __init__(
                 model_types = text_condition_model_types,
                 cond_drop_prob = text_condition_cond_drop_prob
            )
-            cross_attn_dim_context = self.conditioner.dim_latent
+
+            dim_text = self.conditioner.dim_latent
+            cross_attn_dim_context = dim_text
+
+            self.to_sos_text_cond = nn.Linear(dim_text, dim_fine)
 
         # for summarizing the vertices of each face
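
The two __init__ hunks above set up a projection alongside the existing sos token. A short sketch (dim_fine, dim_text and the batch size are made-up numbers) of what those modules do shape-wise, assuming the pooled text embedding comes from a masked mean over text tokens:

import torch
from torch import nn

dim_fine, dim_text = 512, 768                        # hypothetical model dimensions

sos_token = nn.Parameter(torch.randn(dim_fine))      # learned start token for the fine transformer
to_sos_text_cond = nn.Linear(dim_text, dim_fine)     # maps the pooled text latent into the fine dimension

pooled_text_embed = torch.randn(2, dim_text)         # (batch, dim_text)
sos_cond = to_sos_text_cond(pooled_text_embed)       # (batch, dim_fine), added onto the sos token downstream
print(sos_cond.shape)                                # torch.Size([2, 512])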

@@ -1437,10 +1460,24 @@ def forward_on_codes(
 
         attended_face_codes = self.maybe_project_coarse_to_fine(attended_face_codes)
 
-        # auto prepend sos token
+        # repeat sos token across batch
 
         sos = repeat(self.sos_token, 'd -> b d', b = batch)
 
+        # condition sos token if needed
+
+        if self.condition_on_text:
+            pooled_text_embed = masked_mean(
+                maybe_dropped_text_embeds.embed,
+                maybe_dropped_text_embeds.mask,
+                dim = 1
+            )
+
+            sos_cond = self.to_sos_text_cond(pooled_text_embed)
+            sos = sos + sos_cond
+
+        # auto prepend sos token
+
         attended_face_codes_with_sos, _ = pack([sos, attended_face_codes], 'b * d')
 
         grouped_codes = pad_to_length(grouped_codes, attended_face_codes_with_sos.shape[-2], dim = 1)
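
Putting the pieces together, a sketch (tensor shapes and the stand-in sos_cond are illustrative) of the forward-time fix: the sos token is repeated across the batch and, when text conditioning is enabled, shifted by the projected pooled text embedding before being prepended, so the fine tokens of the first face are no longer generated from an unconditioned start token (the issue raised in #80):

import torch
from einops import repeat, pack

batch, dim_fine = 2, 512
sos_token = torch.randn(dim_fine)                         # stands in for self.sos_token
attended_face_codes = torch.randn(batch, 10, dim_fine)    # hypothetical per-face codes from the coarse stage
sos_cond = torch.randn(batch, dim_fine)                   # stands in for to_sos_text_cond(pooled_text_embed)

sos = repeat(sos_token, 'd -> b d', b = batch)            # repeat sos token across batch
sos = sos + sos_cond                                      # condition the start token on the text

# auto prepend sos token, exactly as in the diff
attended_face_codes_with_sos, _ = pack([sos, attended_face_codes], 'b * d')
print(attended_face_codes_with_sos.shape)                 # torch.Size([2, 11, 512])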

meshgpt_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '1.1.1'
+__version__ = '1.2.0'
