Commit 56f0c58

allow for turning off typechecking for torch compile
1 parent 06b1190

7 files changed, +117 −57 lines changed

README.md

Lines changed: 8 additions & 0 deletions

````diff
@@ -132,6 +132,14 @@ mesh_token_ids = autoencoder.tokenize(
 # (batch, num face vertices, residual quantized layer)
 ```
 
+## Typecheck
+
+At the project root, run
+
+```bash
+$ cp .env.sample .env
+```
+
 ## Todo
 
 - [x] autoencoder
````
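Neither `.env.sample` nor the new `meshgpt_pytorch/typing.py` module is shown in this commit, so the exact flag name and implementation are not visible here. A minimal sketch of how such a toggle could plausibly work, assuming a `TYPECHECK` environment variable and jaxtyping-style annotations; apart from `Float`, `Int`, `Bool`, and `typecheck` (which appear in the imports further down), every name below is an assumption:

```python
# sketch of what meshgpt_pytorch/typing.py might contain -- not the actual module
import os

from torch import Tensor

from beartype import beartype
from jaxtyping import Float as JFloat, Int as JInt, Bool as JBool, jaxtyped

# assumed flag name; .env (copied from .env.sample) would define e.g. TYPECHECK=True,
# loaded at import time by something like python-dotenv or environs
should_typecheck = os.environ.get('TYPECHECK', '').lower() in ('1', 'true', 'yes')

class TorchTyping:
    # lets call sites write Float['b nv 3'] instead of JFloat[Tensor, 'b nv 3']
    def __init__(self, abstract_dtype):
        self.abstract_dtype = abstract_dtype

    def __getitem__(self, shapes: str):
        return self.abstract_dtype[Tensor, shapes]

Float = TorchTyping(JFloat)
Int   = TorchTyping(JInt)
Bool  = TorchTyping(JBool)

# when the flag is off, @typecheck is a no-op, so torch.compile never has to
# trace through a beartype wrapper
typecheck = jaxtyped(typechecker = beartype) if should_typecheck else (lambda fn: fn)
```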

meshgpt_pytorch/data.py

Lines changed: 2 additions & 3 deletions

```diff
@@ -2,20 +2,19 @@
 
 from pathlib import Path
 from functools import partial
+
 import torch
 from torch import Tensor
 from torch import is_tensor
-import torch.nn.functional as F
 from torch.utils.data import Dataset
 from torch.nn.utils.rnn import pad_sequence
 
-import numpy as np
 from numpy.lib.format import open_memmap
 
 from einops import rearrange, reduce
 
 from beartype import beartype
-from beartype.typing import Tuple, List, Callable, Dict, Callable
+from beartype.typing import Tuple, List, Callable, Dict
 
 from torchtyping import TensorType
 
```
meshgpt_pytorch/meshgpt_pytorch.py

Lines changed: 36 additions & 40 deletions

```diff
@@ -11,12 +11,10 @@
 from torch.utils.checkpoint import checkpoint
 from torch.cuda.amp import autocast
 
-from torchtyping import TensorType
-
 from pytorch_custom_utils import save_load
 
-from beartype import beartype
 from beartype.typing import Tuple, Callable, List, Dict, Any
+from meshgpt_pytorch.typing import Float, Int, Bool, typecheck
 
 from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
 
@@ -26,7 +24,6 @@
 from einx import get_at
 
 from x_transformers import Decoder
-from x_transformers.attend import Attend
 from x_transformers.x_transformers import RMSNorm, FeedForward, LayerIntermediates
 
 from x_transformers.autoregressive_wrapper import (
@@ -78,8 +75,8 @@ def divisible_by(num, den):
 def is_odd(n):
     return not divisible_by(n, 2)
 
-def is_empty(l):
-    return len(l) == 0
+def is_empty(x):
+    return len(x) == 0
 
 def is_tensor_empty(t: Tensor):
     return t.numel() == 0
@@ -157,7 +154,7 @@ def derive_angle(x, y, eps = 1e-5):
 
 @torch.no_grad()
 def get_derived_face_features(
-    face_coords: TensorType['b', 'nf', 'nvf', 3, float] # 3 or 4 vertices with 3 coordinates
+    face_coords: Float['b nf nvf 3'] # 3 or 4 vertices with 3 coordinates
 ):
     shifted_face_coords = torch.cat((face_coords[:, :, -1:], face_coords[:, :, :-1]), dim = 2)
 
@@ -178,7 +175,7 @@ def get_derived_face_features(
 
 # tensor helper functions
 
-@beartype
+@typecheck
 def discretize(
     t: Tensor,
     *,
@@ -194,7 +191,7 @@ def discretize(
 
     return t.round().long().clamp(min = 0, max = num_discrete - 1)
 
-@beartype
+@typecheck
 def undiscretize(
     t: Tensor,
     *,
@@ -210,7 +207,7 @@ def undiscretize(
     t /= num_discrete
     return t * (hi - lo) + lo
 
-@beartype
+@typecheck
 def gaussian_blur_1d(
     t: Tensor,
     *,
@@ -234,7 +231,7 @@ def gaussian_blur_1d(
     out = F.conv1d(t, kernel, padding = half_width, groups = channels)
     return rearrange(out, 'b c n -> b n c')
 
-@beartype
+@typecheck
 def scatter_mean(
     tgt: Tensor,
     indices: Tensor,
@@ -421,7 +418,7 @@ def forward(
 
 @save_load(version = __version__)
 class MeshAutoencoder(Module):
-    @beartype
+    @typecheck
     def __init__(
         self,
         num_discrete_coors = 128,
@@ -671,15 +668,15 @@ def _from_pretrained(
         model.to(map_location)
         return model
 
-    @beartype
+    @typecheck
     def encode(
         self,
         *,
-        vertices: TensorType['b', 'nv', 3, float],
-        faces: TensorType['b', 'nf', 'nvf', int],
-        face_edges: TensorType['b', 'e', 2, int],
-        face_mask: TensorType['b', 'nf', bool],
-        face_edges_mask: TensorType['b', 'e', bool],
+        vertices: Float['b nv 3'],
+        faces: Int['b nf nvf'],
+        face_edges: Int['b e 2'],
+        face_mask: Bool['b nf'],
+        face_edges_mask: Bool['b e'],
         return_face_coordinates = False
     ):
         """
@@ -692,7 +689,6 @@ def encode(
         d - embed dim
         """
 
-        batch, num_vertices, num_coors, device = *vertices.shape, vertices.device
         _, num_faces, num_vertices_per_face = faces.shape
 
         assert self.num_vertices_per_face == num_vertices_per_face
@@ -773,18 +769,18 @@ def encode(
 
         return face_embed, discrete_face_coords
 
-    @beartype
+    @typecheck
     def quantize(
         self,
         *,
-        faces: TensorType['b', 'nf', 'nvf', int],
-        face_mask: TensorType['b', 'n', bool],
-        face_embed: TensorType['b', 'nf', 'd', float],
+        faces: Int['b nf nvf'],
+        face_mask: Bool['b n'],
+        face_embed: Float['b nf d'],
         pad_id = None,
         rvq_sample_codebook_temp = 1.
     ):
         pad_id = default(pad_id, self.pad_id)
-        batch, num_faces, device = *faces.shape[:2], faces.device
+        batch, device = faces.shape[0], faces.device
 
         max_vertex_index = faces.amax()
         num_vertices = int(max_vertex_index.item() + 1)
@@ -858,11 +854,11 @@ def quantize_wrapper_fn(inp):
 
         return face_embed_output, codes_output, commit_loss
 
-    @beartype
+    @typecheck
     def decode(
         self,
-        quantized: TensorType['b', 'n', 'd', float],
-        face_mask: TensorType['b', 'n', bool]
+        quantized: Float['b n d'],
+        face_mask: Bool['b n']
     ):
         conv_face_mask = rearrange(face_mask, 'b n -> b 1 n')
 
@@ -884,12 +880,12 @@ def decode(
 
         return rearrange(x, 'b d n -> b n d')
 
-    @beartype
+    @typecheck
     @torch.no_grad()
     def decode_from_codes_to_faces(
         self,
         codes: Tensor,
-        face_mask: TensorType['b', 'n', bool] | None = None,
+        face_mask: Bool['b n'] | None = None,
         return_discrete_codes = False
     ):
         codes = rearrange(codes, 'b ... -> b (...)')
@@ -964,13 +960,13 @@ def tokenize(self, vertices, faces, face_edges = None, **kwargs):
 
         return codes
 
-    @beartype
+    @typecheck
     def forward(
         self,
         *,
-        vertices: TensorType['b', 'nv', 3, float],
-        faces: TensorType['b', 'nf', 'nvf', int],
-        face_edges: TensorType['b', 'e', 2, int] | None = None,
+        vertices: Float['b nv 3'],
+        faces: Int['b nf nvf'],
+        face_edges: Int['b e 2'] | None = None,
         return_codes = False,
         return_loss_breakdown = False,
         return_recon_faces = False,
@@ -980,7 +976,7 @@ def forward(
         if not exists(face_edges):
             face_edges = derive_face_edges_from_faces(faces, pad_id = self.pad_id)
 
-        num_faces, num_face_edges, device = faces.shape[1], face_edges.shape[1], faces.device
+        device = faces.device
 
         face_mask = reduce(faces != self.pad_id, 'b nf c -> b nf', 'all')
         face_edges_mask = reduce(face_edges != self.pad_id, 'b e ij -> b e', 'all')
@@ -1079,7 +1075,7 @@ def forward(
 
 @save_load(version = __version__)
 class MeshTransformer(Module,PyTorchModelHubMixin):
-    @beartype
+    @typecheck
     def __init__(
         self,
         autoencoder: MeshAutoencoder,
@@ -1270,7 +1266,7 @@ def _from_pretrained(
     def device(self):
         return next(self.parameters()).device
 
-    @beartype
+    @typecheck
     @torch.no_grad()
     def embed_texts(self, texts: str | List[str]):
         single_text = not isinstance(texts, list)
@@ -1287,7 +1283,7 @@ def embed_texts(self, texts: str | List[str]):
 
     @eval_decorator
     @torch.no_grad()
-    @beartype
+    @typecheck
     def generate(
         self,
         prompt: Tensor | None = None,
@@ -1406,9 +1402,9 @@ def generate(
     def forward(
         self,
         *,
-        vertices: TensorType['b', 'nv', 3, int],
-        faces: TensorType['b', 'nf', 'nvf', int],
-        face_edges: TensorType['b', 'e', 2, int] | None = None,
+        vertices: Int['b nv 3'],
+        faces: Int['b nf nvf'],
+        face_edges: Int['b e 2'] | None = None,
         codes: Tensor | None = None,
         cache: LayerIntermediates | None = None,
         **kwargs
```
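Since the shorthand annotations replace `torchtyping.TensorType` throughout, a hypothetical call site (the function below is illustrative, not from the repo) would look like this; with the flag off, the decorator adds no runtime overhead:

```python
import torch

from meshgpt_pytorch.typing import Float, typecheck

@typecheck
def center_vertices(
    vertices: Float['b nv 3']  # batch, num vertices, xyz
) -> Float['b nv 3']:
    # shift each mesh so its vertex centroid sits at the origin
    return vertices - vertices.mean(dim = 1, keepdim = True)

vertices = torch.randn(2, 121, 3)

# with typechecking enabled, a mis-shaped input such as (2, 121, 2) raises at
# call time; with it disabled, @typecheck is a no-op and torch.compile sees a
# plain, unwrapped function
centered = center_vertices(vertices)
```

Presumably this is the point of the commit: runtime typecheck wrappers like beartype's tend to cause graph breaks under `torch.compile`, so an environment flag keeps the shape annotations useful during development while leaving compiled runs unobstructed.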

meshgpt_pytorch/trainer.py

Lines changed: 9 additions & 10 deletions

```diff
@@ -1,12 +1,10 @@
 from pathlib import Path
 from functools import partial
 from packaging import version
-from contextlib import nullcontext, contextmanager
+from contextlib import nullcontext
 
 import torch
-from torch import nn, Tensor
 from torch.nn import Module
-import torch.nn.functional as F
 from torch.utils.data import Dataset, DataLoader
 from torch.optim.lr_scheduler import _LRScheduler
 
@@ -19,9 +17,8 @@
 from accelerate import Accelerator
 from accelerate.utils import DistributedDataParallelKwargs
 
-from beartype import beartype
-from beartype.door import is_bearable
 from beartype.typing import Tuple, Type, List
+from meshgpt_pytorch.typing import typecheck, beartype_isinstance
 
 from ema_pytorch import EMA
 
@@ -67,7 +64,7 @@ def maybe_del(d: dict, *keys):
 
 @add_wandb_tracker_contextmanager()
 class MeshAutoencoderTrainer(Module):
-    @beartype
+    @typecheck
     def __init__(
         self,
         model: MeshAutoencoder,
@@ -81,7 +78,9 @@ def __init__(
         learning_rate: float = 1e-4,
         weight_decay: float = 0.,
         max_grad_norm: float | None = None,
-        ema_kwargs: dict = dict(),
+        ema_kwargs: dict = dict(
+            use_foreach = True
+        ),
         scheduler: Type[_LRScheduler] | None = None,
         scheduler_kwargs: dict = dict(),
         accelerator_kwargs: dict = dict(),
@@ -147,7 +146,7 @@ def __init__(
         )
 
         if hasattr(dataset, 'data_kwargs') and exists(dataset.data_kwargs):
-            assert is_bearable(dataset.data_kwargs, List[str])
+            assert beartype_isinstance(dataset.data_kwargs, List[str])
             self.data_kwargs = dataset.data_kwargs
         else:
             self.data_kwargs = data_kwargs
@@ -324,7 +323,7 @@ def forward(self):
 
 @add_wandb_tracker_contextmanager()
 class MeshTransformerTrainer(Module):
-    @beartype
+    @typecheck
     def __init__(
         self,
         model: MeshTransformer,
@@ -407,7 +406,7 @@ def __init__(
         )
 
         if hasattr(dataset, 'data_kwargs') and exists(dataset.data_kwargs):
-            assert is_bearable(dataset.data_kwargs, List[str])
+            assert beartype_isinstance(dataset.data_kwargs, List[str])
             self.data_kwargs = dataset.data_kwargs
         else:
             self.data_kwargs = data_kwargs
```
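`beartype_isinstance` is also imported from the new typing module but not defined anywhere in this diff. It presumably just wraps `beartype.door.is_bearable` so trainer code depends only on `meshgpt_pytorch.typing` rather than on beartype internals; a plausible sketch:

```python
# hedged sketch -- a thin alias over beartype's runtime checker
from beartype.door import is_bearable

def beartype_isinstance(obj, hint) -> bool:
    # True if obj satisfies the (possibly deeply nested) type hint at runtime,
    # e.g. beartype_isinstance(['vertices', 'faces'], List[str])
    return is_bearable(obj, hint)
```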
