@@ -16,7 +16,7 @@ def test_readme(adaptive_rmsnorm):
     # mock inputs

     vertices = torch.randn((2, 121, 3))        # (batch, num vertices, coor (3))
-    faces = torch.randint(0, 121, (2, 64, 3))  # (batch, num faces, vertices (3))
+    faces = torch.randint(0, 121, (2, 2, 3))   # (batch, num faces, vertices (3))

     # forward in the faces

@@ -33,7 +33,7 @@ def test_readme(adaptive_rmsnorm):
     transformer = MeshTransformer(
         autoencoder,
         dim = 512,
-        max_seq_len = 768,
+        max_seq_len = 60,
         num_sos_tokens = 1,
         fine_cross_attend_text = True,
         text_cond_with_film = False,
@@ -51,3 +51,22 @@ def test_readme(adaptive_rmsnorm):
     loss.backward()

     faces_coordinates, face_mask = transformer.generate(texts = ['a small chair'], cond_scale = 3.)
+
+def test_cache():
+    # test that the output of generation with and without the kv cache (and optional gateloop cache) is equivalent
+
+    autoencoder = MeshAutoencoder(
+        num_discrete_coors = 128
+    )
+
+    transformer = MeshTransformer(
+        autoencoder,
+        dim = 512,
+        max_seq_len = 12
+    )
+
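+    # temperature = 0 should make sampling greedy (deterministic), so the cached and uncached decodes can be compared exactly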
+    uncached_faces_coors, _ = transformer.generate(cache_kv = False, temperature = 0)
+    cached_faces_coors, _ = transformer.generate(cache_kv = True, temperature = 0)
+
+    assert torch.allclose(uncached_faces_coors, cached_faces_coors)