11
11
from torch .utils .checkpoint import checkpoint
12
12
from torch .cuda .amp import autocast
13
13
14
- from torchtyping import TensorType
15
-
16
14
from pytorch_custom_utils import save_load
17
15
18
- from beartype import beartype
19
16
from beartype .typing import Tuple , Callable , List , Dict , Any
17
+ from meshgpt_pytorch .typing import Float , Int , Bool , typecheck
20
18
21
19
from huggingface_hub import PyTorchModelHubMixin , hf_hub_download
22
20
26
24
from einx import get_at
27
25
28
26
from x_transformers import Decoder
29
- from x_transformers .attend import Attend
30
27
from x_transformers .x_transformers import RMSNorm , FeedForward , LayerIntermediates
31
28
32
29
from x_transformers .autoregressive_wrapper import (
@@ -78,8 +75,8 @@ def divisible_by(num, den):
78
75
def is_odd (n ):
79
76
return not divisible_by (n , 2 )
80
77
81
- def is_empty (l ):
82
- return len (l ) == 0
78
+ def is_empty (x ):
79
+ return len (x ) == 0
83
80
84
81
def is_tensor_empty (t : Tensor ):
85
82
return t .numel () == 0
@@ -157,7 +154,7 @@ def derive_angle(x, y, eps = 1e-5):
157
154
158
155
@torch .no_grad ()
159
156
def get_derived_face_features (
160
- face_coords : TensorType ['b' , 'nf' , ' nvf' , 3 , float ] # 3 or 4 vertices with 3 coordinates
157
+ face_coords : Float ['b nf nvf 3' ] # 3 or 4 vertices with 3 coordinates
161
158
):
162
159
shifted_face_coords = torch .cat ((face_coords [:, :, - 1 :], face_coords [:, :, :- 1 ]), dim = 2 )
163
160
@@ -178,7 +175,7 @@ def get_derived_face_features(
178
175
179
176
# tensor helper functions
180
177
181
- @beartype
178
+ @typecheck
182
179
def discretize (
183
180
t : Tensor ,
184
181
* ,
@@ -194,7 +191,7 @@ def discretize(
194
191
195
192
return t .round ().long ().clamp (min = 0 , max = num_discrete - 1 )
196
193
197
- @beartype
194
+ @typecheck
198
195
def undiscretize (
199
196
t : Tensor ,
200
197
* ,
@@ -210,7 +207,7 @@ def undiscretize(
210
207
t /= num_discrete
211
208
return t * (hi - lo ) + lo
212
209
213
- @beartype
210
+ @typecheck
214
211
def gaussian_blur_1d (
215
212
t : Tensor ,
216
213
* ,
@@ -234,7 +231,7 @@ def gaussian_blur_1d(
234
231
out = F .conv1d (t , kernel , padding = half_width , groups = channels )
235
232
return rearrange (out , 'b c n -> b n c' )
236
233
237
- @beartype
234
+ @typecheck
238
235
def scatter_mean (
239
236
tgt : Tensor ,
240
237
indices : Tensor ,
@@ -421,7 +418,7 @@ def forward(
421
418
422
419
@save_load (version = __version__ )
423
420
class MeshAutoencoder (Module ):
424
- @beartype
421
+ @typecheck
425
422
def __init__ (
426
423
self ,
427
424
num_discrete_coors = 128 ,
@@ -671,15 +668,15 @@ def _from_pretrained(
671
668
model .to (map_location )
672
669
return model
673
670
674
- @beartype
671
+ @typecheck
675
672
def encode (
676
673
self ,
677
674
* ,
678
- vertices : TensorType ['b' , 'nv' , 3 , float ],
679
- faces : TensorType ['b' , 'nf' , ' nvf', int ],
680
- face_edges : TensorType ['b' , 'e' , 2 , int ],
681
- face_mask : TensorType ['b' , ' nf', bool ],
682
- face_edges_mask : TensorType ['b' , 'e' , bool ],
675
+ vertices : Float ['b nv 3' ],
676
+ faces : Int ['b nf nvf' ],
677
+ face_edges : Int ['b e 2' ],
678
+ face_mask : Bool ['b nf' ],
679
+ face_edges_mask : Bool ['b e' ],
683
680
return_face_coordinates = False
684
681
):
685
682
"""
@@ -692,7 +689,6 @@ def encode(
692
689
d - embed dim
693
690
"""
694
691
695
- batch , num_vertices , num_coors , device = * vertices .shape , vertices .device
696
692
_ , num_faces , num_vertices_per_face = faces .shape
697
693
698
694
assert self .num_vertices_per_face == num_vertices_per_face
@@ -773,18 +769,18 @@ def encode(
773
769
774
770
return face_embed , discrete_face_coords
775
771
776
- @beartype
772
+ @typecheck
777
773
def quantize (
778
774
self ,
779
775
* ,
780
- faces : TensorType ['b' , 'nf' , ' nvf', int ],
781
- face_mask : TensorType ['b' , 'n' , bool ],
782
- face_embed : TensorType ['b' , 'nf' , 'd' , float ],
776
+ faces : Int ['b nf nvf' ],
777
+ face_mask : Bool ['b n' ],
778
+ face_embed : Float ['b nf d' ],
783
779
pad_id = None ,
784
780
rvq_sample_codebook_temp = 1.
785
781
):
786
782
pad_id = default (pad_id , self .pad_id )
787
- batch , num_faces , device = * faces .shape [: 2 ], faces .device
783
+ batch , device = faces .shape [0 ], faces .device
788
784
789
785
max_vertex_index = faces .amax ()
790
786
num_vertices = int (max_vertex_index .item () + 1 )
@@ -858,11 +854,11 @@ def quantize_wrapper_fn(inp):
858
854
859
855
return face_embed_output , codes_output , commit_loss
860
856
861
- @beartype
857
+ @typecheck
862
858
def decode (
863
859
self ,
864
- quantized : TensorType ['b' , 'n' , 'd' , float ],
865
- face_mask : TensorType ['b' , 'n' , bool ]
860
+ quantized : Float ['b n d' ],
861
+ face_mask : Bool ['b n' ]
866
862
):
867
863
conv_face_mask = rearrange (face_mask , 'b n -> b 1 n' )
868
864
@@ -884,12 +880,12 @@ def decode(
884
880
885
881
return rearrange (x , 'b d n -> b n d' )
886
882
887
- @beartype
883
+ @typecheck
888
884
@torch .no_grad ()
889
885
def decode_from_codes_to_faces (
890
886
self ,
891
887
codes : Tensor ,
892
- face_mask : TensorType ['b' , 'n' , bool ] | None = None ,
888
+ face_mask : Bool ['b n' ] | None = None ,
893
889
return_discrete_codes = False
894
890
):
895
891
codes = rearrange (codes , 'b ... -> b (...)' )
@@ -964,13 +960,13 @@ def tokenize(self, vertices, faces, face_edges = None, **kwargs):
964
960
965
961
return codes
966
962
967
- @beartype
963
+ @typecheck
968
964
def forward (
969
965
self ,
970
966
* ,
971
- vertices : TensorType ['b' , 'nv' , 3 , float ],
972
- faces : TensorType ['b' , 'nf' , ' nvf', int ],
973
- face_edges : TensorType ['b' , 'e' , 2 , int ] | None = None ,
967
+ vertices : Float ['b nv 3' ],
968
+ faces : Int ['b nf nvf' ],
969
+ face_edges : Int ['b e 2' ] | None = None ,
974
970
return_codes = False ,
975
971
return_loss_breakdown = False ,
976
972
return_recon_faces = False ,
@@ -980,7 +976,7 @@ def forward(
980
976
if not exists (face_edges ):
981
977
face_edges = derive_face_edges_from_faces (faces , pad_id = self .pad_id )
982
978
983
- num_faces , num_face_edges , device = faces . shape [ 1 ], face_edges . shape [ 1 ], faces .device
979
+ device = faces .device
984
980
985
981
face_mask = reduce (faces != self .pad_id , 'b nf c -> b nf' , 'all' )
986
982
face_edges_mask = reduce (face_edges != self .pad_id , 'b e ij -> b e' , 'all' )
@@ -1079,7 +1075,7 @@ def forward(
1079
1075
1080
1076
@save_load (version = __version__ )
1081
1077
class MeshTransformer (Module ,PyTorchModelHubMixin ):
1082
- @beartype
1078
+ @typecheck
1083
1079
def __init__ (
1084
1080
self ,
1085
1081
autoencoder : MeshAutoencoder ,
@@ -1270,7 +1266,7 @@ def _from_pretrained(
1270
1266
def device (self ):
1271
1267
return next (self .parameters ()).device
1272
1268
1273
- @beartype
1269
+ @typecheck
1274
1270
@torch .no_grad ()
1275
1271
def embed_texts (self , texts : str | List [str ]):
1276
1272
single_text = not isinstance (texts , list )
@@ -1287,7 +1283,7 @@ def embed_texts(self, texts: str | List[str]):
1287
1283
1288
1284
@eval_decorator
1289
1285
@torch .no_grad ()
1290
- @beartype
1286
+ @typecheck
1291
1287
def generate (
1292
1288
self ,
1293
1289
prompt : Tensor | None = None ,
@@ -1406,9 +1402,9 @@ def generate(
1406
1402
def forward (
1407
1403
self ,
1408
1404
* ,
1409
- vertices : TensorType ['b' , 'nv' , 3 , int ],
1410
- faces : TensorType ['b' , 'nf' , ' nvf', int ],
1411
- face_edges : TensorType ['b' , 'e' , 2 , int ] | None = None ,
1405
+ vertices : Int ['b nv 3' ],
1406
+ faces : Int ['b nf nvf' ],
1407
+ face_edges : Int ['b e 2' ] | None = None ,
1412
1408
codes : Tensor | None = None ,
1413
1409
cache : LayerIntermediates | None = None ,
1414
1410
** kwargs
0 commit comments