Skip to content

Commit 842f67b

Browse files
committed
sync with Threadripper
1 parent 69e42a2 commit 842f67b

File tree

3 files changed

+24
-21
lines changed

3 files changed

+24
-21
lines changed

mp_transformer/config.py

+18-12
Original file line number | Diff line number | Diff line change
@@ -8,17 +8,17 @@
88
# "latent_dim": 32,
99
# "latent_dim": 48,
1010
# "latent_dim": 64,
11-
# "latent_dim": 128,
12-
"latent_dim": 256,
13-
#"num_primitives": 4,
11+
"latent_dim": 128,
12+
# "latent_dim": 256,
13+
# "num_primitives": 4,
1414
"num_primitives": 6,
1515
# "num_primitives": 8,
1616
# "hidden_dim": 16,
1717
# "hidden_dim": 32,
1818
# "hidden_dim": 48,
1919
# "hidden_dim": 64,
20-
# "hidden_dim": 128,
21-
"hidden_dim": 256,
20+
"hidden_dim": 128,
21+
# "hidden_dim": 256,
2222
"learn_segmentation": True,
2323
"masking_slope": 1,
2424
# "masking_slope": 0.75,
@@ -27,7 +27,9 @@
2727
# "kl_weight": 5e-3,
2828
# "kl_weight": 1e-2,
2929
# "kl_weight": 2e-2,
30-
"kl_weight": 1e-3,
30+
# "kl_weight": 5e-2,
31+
"kl_weight": 1e-1,
32+
# "kl_weight": 1e-3,
3133
# "kl_weight": 1e-4,
3234
"anneal_start": 10,
3335
# "anneal_start": 5,
@@ -41,7 +43,7 @@
4143
# "cycle_length": 100,
4244
# "cycle_length": 200,
4345
# "durations_weight": 1e-6,
44-
#"durations_weight": 1e-4,
46+
# "durations_weight": 1e-4,
4547
"durations_weight": 0,
4648
# "durations_weight": 1e-5,
4749
"lr": 1e-4,
@@ -53,10 +55,10 @@
5355
"N_val": 40000,
5456
# "N_val": 2,
5557
"sequence_length": 128,
56-
#"epochs": 200,
58+
# "epochs": 200,
5759
# "epochs": 250,
5860
# "epochs": 230,
59-
# "epochs": 3000,
61+
# "epochs": 2000,
6062
# "epochs": 1000,
6163
# "epochs": 300,
6264
"epochs": 500,
@@ -67,14 +69,18 @@
6769
# "run_name": "fresh-Transformer",
6870
# "run_name": "smol-Transformer",
6971
# "run_name": "resume-Transformer",
70-
#"run_name": "smol-Transformer",
71-
#"run_name": "notrafo-Transformer",
72+
# "run_name": "smol-Transformer",
73+
# "run_name": "notrafo-Transformer",
7274
# "run_name": "notrafo-sigmoid-Transformer",
7375
# "run_name": "noanneal-sigmoid-Transformer",
7476
# "run_name": "sigmoid-Transformer",
7577
# "run_name": "midKL-Transformer",
76-
"run_name": "lowKL-Transformer",
78+
# "run_name": "lowKL-Transformer",
79+
# "run_name": "highKL-Transformer",
80+
"run_name": "veryhighKL-Transformer",
81+
# "run_name": "noanneal-highKL-Transformer",
7782
# "run_name": "cyclical-Transformer",
83+
# "run_name": "cyclical-lowKL-Transformer",
7884
# "run_name": "nosigmoid-Transformer",
7985
# "run_name": "nosigmoid-Transformer",
8086
# "run_name": "relu-sigmoid-Transformer",

mp_transformer/datasets/toy_dataset.py

+5-8
Original file line number | Diff line number | Diff line change
@@ -9,8 +9,6 @@
99
from torch.utils.data import Dataset
1010
from torchvision import transforms
1111

12-
from mp_transformer.utils.generate_toydata import forward
13-
1412
PIL.PILLOW_VERSION = PIL.__version__ # torchvision bug
1513

1614

@@ -22,18 +20,17 @@ def normalize_pose(pose):
2220
# print(pose)
2321
# print(np.where(pose < -1))
2422
# print(np.where(pose > 1))
25-
# assert np.all(pose >= -1) and np.all(pose <= 1)
26-
# pose = (pose + 1) / 2
27-
# assert np.all(pose >= 0) and np.all(pose <= 1)
28-
29-
pose = forward(pose)
23+
assert np.all(pose >= -1) and np.all(pose <= 1)
24+
pose = (pose + 1) / 2
25+
assert np.all(pose >= 0) and np.all(pose <= 1)
26+
3027
return pose
3128

3229

3330
def unnormalize_pose(pose):
3431
"""Transform [0, 1] to [-1, 1]"""
3532
# assert np.all(pose >= 0) and np.all(pose <= 1)
36-
# pose = pose * 2 - 1
33+
pose = pose * 2 - 1
3734
# assert np.all(pose >= -1) and np.all(pose <= 1)
3835
return pose
3936

mp_transformer/utils/generate_toy_data.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -241,6 +241,6 @@ def main(iterations=1, train_or_val="train", gen_images=True, N=5000):
241241
# ITERATIONS = 16
242242
ITERATIONS = 20
243243
TRAIN_OR_VAL = "both"
244-
#TRAIN_OR_VAL = "train"
244+
# TRAIN_OR_VAL = "train"
245245
# TRAIN_OR_VAL = "val"
246246
main(iterations=ITERATIONS, train_or_val=TRAIN_OR_VAL, gen_images=GEN_IMAGES, N=N)

0 commit comments

Comments
 (0)