Commit 443ebc7

add a test for caching
1 parent 36d8dba commit 443ebc7

File tree: 3 files changed (+22, -3 lines)

    meshgpt_pytorch/meshgpt_pytorch.py
    meshgpt_pytorch/version.py
    tests/test_meshgpt.py

meshgpt_pytorch/meshgpt_pytorch.py

Lines changed: 1 addition & 0 deletions
@@ -1157,6 +1157,7 @@ def __init__(
         self.conditioner = None
 
         cross_attn_dim_context = None
+        dim_text = None
 
         if condition_on_text:
             self.conditioner = TextEmbeddingReturner(
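
The single added line gives dim_text a default before the if condition_on_text: branch, mirroring the existing cross_attn_dim_context = None just above it, so code after the branch can reference the variable even when text conditioning is disabled. A minimal sketch of that pattern, with hypothetical names and values rather than the library's actual attributes:

```python
# minimal sketch of the None-default pattern behind the one-line change;
# condition_on_text and the 768 dim are hypothetical stand-ins, not library values
condition_on_text = False

dim_text = None   # defined unconditionally, like cross_attn_dim_context above

if condition_on_text:
    dim_text = 768   # would normally come from the text conditioner

# later code can branch on the variable even when text conditioning is off,
# instead of hitting an UnboundLocalError / NameError
needs_text_film = dim_text is not None
```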

meshgpt_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '1.5.1'
+__version__ = '1.5.2'

tests/test_meshgpt.py

Lines changed: 20 additions & 2 deletions
@@ -16,7 +16,7 @@ def test_readme(adaptive_rmsnorm):
     # mock inputs
 
     vertices = torch.randn((2, 121, 3)) # (batch, num vertices, coor (3))
-    faces = torch.randint(0, 121, (2, 64, 3)) # (batch, num faces, vertices (3))
+    faces = torch.randint(0, 121, (2, 2, 3)) # (batch, num faces, vertices (3))
 
     # forward in the faces
 
@@ -33,7 +33,7 @@ def test_readme(adaptive_rmsnorm):
     transformer = MeshTransformer(
         autoencoder,
         dim = 512,
-        max_seq_len = 768,
+        max_seq_len = 60,
         num_sos_tokens = 1,
         fine_cross_attend_text = True,
         text_cond_with_film = False,
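
Shrinking the mock mesh (2 faces instead of 64) and dropping max_seq_len from 768 to 60 keeps the readme test fast. A back-of-the-envelope check that the smaller budget still fits, assuming each face is tokenized into 3 * num_quantizers codes with 2 residual quantizers (an assumption about the autoencoder defaults, not something this diff states):

```python
# rough token budget for the mock inputs; the per-face code count is assumed,
# not taken from this commit
num_faces         = 2   # faces tensor above is (batch, 2, 3)
vertices_per_face = 3
codes_per_vertex  = 2   # assumed number of residual quantizers

tokens_needed = num_faces * vertices_per_face * codes_per_vertex   # 12
assert tokens_needed <= 60   # comfortably within the reduced max_seq_len
```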
@@ -51,3 +51,21 @@ def test_readme(adaptive_rmsnorm):
     loss.backward()
 
     faces_coordinates, face_mask = transformer.generate(texts = ['a small chair'], cond_scale = 3.)
+
+def test_cache():
+    # test that the output for generation with and without kv (and optional gateloop) cache is equivalent
+
+    autoencoder = MeshAutoencoder(
+        num_discrete_coors = 128
+    )
+
+    transformer = MeshTransformer(
+        autoencoder,
+        dim = 512,
+        max_seq_len = 12
+    )
+
+    uncached_faces_coors, _ = transformer.generate(cache_kv = False, temperature = 0)
+    cached_faces_coors, _ = transformer.generate(cache_kv = True, temperature = 0)
+
+    assert torch.allclose(uncached_faces_coors, cached_faces_coors)
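
The new test_cache exercises the commit's purpose: with temperature = 0 the sampling step should collapse to greedy argmax decoding, so generation with the kv (and optional gateloop) cache enabled must reproduce the uncached output token for token, which torch.allclose then verifies on the returned face coordinates. The same equivalence check applies to any autoregressive decoder; a self-contained sketch with a toy model, not meshgpt_pytorch internals:

```python
# generic cached-vs-uncached greedy decoding equivalence check;
# the toy "model" below is a stand-in, not meshgpt_pytorch code
import torch

VOCAB = 7

def logits_from_state(state):
    # deterministic toy model: logits depend only on a summary of the prefix
    logits = torch.zeros(VOCAB)
    logits[(state * 3 + 1) % VOCAB] = 1.0
    return logits

def generate(num_steps, cache = False):
    tokens, state = [0], 0
    for _ in range(num_steps):
        if cache:
            state = state + tokens[-1]   # incremental update, analogous to a kv cache
        else:
            state = sum(tokens)          # recompute the summary from the full prefix
        tokens.append(int(logits_from_state(state).argmax()))   # temperature 0 -> argmax
    return torch.tensor(tokens)

# a correct cache changes only the cost of each step, never the output
assert torch.equal(generate(12, cache = False), generate(12, cache = True))
```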
