@@ -1038,10 +1038,11 @@ def __init__(
         fine_attn_dim_head = 32,
         fine_attn_heads = 8,
         pad_id = -1,
+        num_sos_tokens = 1,
         condition_on_text = False,
         text_condition_model_types = ('t5',),
         text_condition_cond_drop_prob = 0.25,
-        quads = False
+        quads = False,
     ):
         super().__init__()
         self.num_vertices_per_face = 3 if not quads else 4
@@ -1060,9 +1061,11 @@ def __init__(

         # the fine transformer sos token
         # as well as a projection of pooled text embeddings to condition it
-        # (todo) - sos token should be moved to the coarse transformer stage

-        self.sos_token = nn.Parameter(torch.randn(dim_fine))
+        assert num_sos_tokens > 0
+
+        self.num_sos_tokens = num_sos_tokens
+        self.sos_token = nn.Parameter(torch.randn(num_sos_tokens, dim))

         # they use axial positional embeddings
@@ -1444,8 +1447,8 @@ def forward_on_codes(
1444
1447
else :
1445
1448
# auto prepend sos token
1446
1449
1447
- sos = repeat (self .sos_token , 'd -> b d' , b = batch )
1448
- face_codes , _ = pack ([sos , face_codes ], 'b * d' )
1450
+ sos = repeat (self .sos_token , 'n d -> b n d' , b = batch )
1451
+ face_codes , packed_sos_shape = pack ([sos , face_codes ], 'b * d' )
1449
1452
1450
1453
# if no kv cache, always call first transformer
1451
1454
@@ -1470,6 +1473,13 @@ def forward_on_codes(
1470
1473
1471
1474
attended_face_codes = safe_cat ((cached_attended_face_codes , attended_face_codes ), dim = - 2 )
1472
1475
1476
+ # if calling without kv cache, pool the sos tokens, if greater than 1 sos token
1477
+
1478
+ if not exists (cache ) and self .num_sos_tokens > 1 :
1479
+ sos_tokens , attended_face_codes = unpack (attended_face_codes , packed_sos_shape , 'b * d' )
1480
+ pooled_sos_token = reduce (sos_tokens , 'b n d -> b 1 d' , 'mean' )
1481
+ attended_face_codes = torch .cat ((pooled_sos_token , attended_face_codes ), dim = 1 )
1482
+
1473
1483
# maybe project from coarse to fine dimension for hierarchical transformers
1474
1484
1475
1485
attended_face_codes = self .maybe_project_coarse_to_fine (attended_face_codes )
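For reference, here is a minimal standalone sketch (not part of the diff) of the multi-sos-token path added above: several learned sos tokens are prepended before the coarse transformer, then mean-pooled back down to a single token. The names batch, num_faces, dim and the tensor sizes are illustrative; the einops calls mirror the repeat/pack/unpack/reduce usage in the diff.

import torch
from torch import nn
from einops import repeat, pack, unpack, reduce

batch, num_faces, dim = 2, 5, 16   # illustrative sizes
num_sos_tokens = 4

sos_token = nn.Parameter(torch.randn(num_sos_tokens, dim))
face_codes = torch.randn(batch, num_faces, dim)

# prepend the sos tokens, keeping the packed shapes so they can be split off later
sos = repeat(sos_token, 'n d -> b n d', b = batch)
face_codes, packed_sos_shape = pack([sos, face_codes], 'b * d')

# ... the coarse transformer would attend over face_codes here ...
attended_face_codes = face_codes

# split the sos tokens back out and pool them down to a single token
sos_tokens, attended_face_codes = unpack(attended_face_codes, packed_sos_shape, 'b * d')
pooled_sos_token = reduce(sos_tokens, 'b n d -> b 1 d', 'mean')
attended_face_codes = torch.cat((pooled_sos_token, attended_face_codes), dim = 1)

assert attended_face_codes.shape == (batch, num_faces + 1, dim)

Pooling the extra sos tokens back down to one means the sequence seen by the rest of the pipeline (and by the kv-cached decoding path, which skips this branch) has the same length as in the single-sos-token case.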