Commit 3a2bd11

glimpse of hope
1 parent 842f67b commit 3a2bd11

File tree: 4 files changed (+71 −27 lines)

mp_transformer/config.py (+38 −14)

@@ -3,65 +3,80 @@
     "pose_dim": 6,
     "num_attention_heads": 4,
     "num_transformer_layers": 4,
+    # "num_transformer_layers": 2,
+    # "latent_dim": 4,
+    # "latent_dim": 6,
     # "latent_dim": 8,
+    "latent_dim": 12,
     # "latent_dim": 16,
     # "latent_dim": 32,
     # "latent_dim": 48,
     # "latent_dim": 64,
-    "latent_dim": 128,
+    # "latent_dim": 128,
     # "latent_dim": 256,
     # "num_primitives": 4,
     "num_primitives": 6,
     # "num_primitives": 8,
+    # "hidden_dim": 4,
+    # "hidden_dim": 6,
+    # "hidden_dim": 8,
     # "hidden_dim": 16,
     # "hidden_dim": 32,
     # "hidden_dim": 48,
+    # "hidden_dim": 56,
     # "hidden_dim": 64,
+    # "hidden_dim": 80,
+    # "hidden_dim": 96,
     "hidden_dim": 128,
     # "hidden_dim": 256,
     "learn_segmentation": True,
     "masking_slope": 1,
     # "masking_slope": 0.75,
     # "masking_slope": 0.5,
+    # "kl_weight": 1e-6,
     # "kl_weight": 1e-5,
     # "kl_weight": 5e-3,
     # "kl_weight": 1e-2,
     # "kl_weight": 2e-2,
     # "kl_weight": 5e-2,
-    "kl_weight": 1e-1,
+    # "kl_weight": 1e-1,
     # "kl_weight": 1e-3,
-    # "kl_weight": 1e-4,
-    "anneal_start": 10,
+    "kl_weight": 1e-4,
+    # "anneal_start": 10,
+    # "anneal_start": 199,
     # "anneal_start": 5,
-    # "anneal_start": 0,
-    "anneal_end": 50,
+    "anneal_start": 0,
+    # "anneal_end": 50,
+    # "anneal_end": 100,
+    # "anneal_end": 199,
     # "anneal_end": 30,
     # "anneal_end": 20,
     # "anneal_end": 15,
-    # "anneal_end": 0,
+    "anneal_end": 0,
     "cycle_length": None,
     # "cycle_length": 100,
     # "cycle_length": 200,
-    # "durations_weight": 1e-6,
+    "durations_weight": 1e-6,
     # "durations_weight": 1e-4,
-    "durations_weight": 0,
+    # "durations_weight": 0,
     # "durations_weight": 1e-5,
-    "lr": 1e-4,
-    # "lr": 5e-4,
+    # "lr": 1e-4,
+    "lr": 2e-4,
+    # "lr": 4e-4,
     "batch_size": 8,
     # "batch_size": 16,
     "N_train": 200000,
     # "N_train": 2,
     "N_val": 40000,
     # "N_val": 2,
     "sequence_length": 128,
-    # "epochs": 200,
+    "epochs": 200,
     # "epochs": 250,
     # "epochs": 230,
     # "epochs": 2000,
     # "epochs": 1000,
     # "epochs": 300,
-    "epochs": 500,
+    # "epochs": 500,
     # "epochs": 800,
     # "epochs": 400,
     # "epochs": 5,
@@ -77,13 +92,22 @@
     # "run_name": "midKL-Transformer",
     # "run_name": "lowKL-Transformer",
     # "run_name": "highKL-Transformer",
-    "run_name": "veryhighKL-Transformer",
+    # "run_name": "veryhighKL-Transformer",
+    # "run_name": "tiny-Transformer",
+    # "run_name": "noKLmedium-Transformer",
+    # "run_name": "slowanneal-Transformer",
+    # "run_name": "tinyKL-Transformer",
     # "run_name": "noanneal-highKL-Transformer",
     # "run_name": "cyclical-Transformer",
     # "run_name": "cyclical-lowKL-Transformer",
     # "run_name": "nosigmoid-Transformer",
     # "run_name": "nosigmoid-Transformer",
     # "run_name": "relu-sigmoid-Transformer",
+    # "run_name": "bottleneck-Transformer",
+    # "run_name": "more-bottleneck",
+    # "run_name": "lowKL-most-bottleneck",
+    "run_name": "KL-most-bottleneck",
+    # "run_name": "short-more-bottleneck",
 }
 
 # for hyperparameter tuning with wandb sweep
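This commit narrows the model and weakens the KL term: latent_dim drops from 128 to 12, kl_weight from 1e-1 to 1e-4, and annealing is switched off (anneal_start = anneal_end = 0). As a rough illustration of how these keys could interact, below is a minimal sketch of a linear KL ramp driven by anneal_start and anneal_end; kl_annealing_factor is a hypothetical helper written for this note, not a function from this repository, whose scheduler may work differently.

# Hedged sketch: one plausible reading of the annealing keys above.
# `kl_annealing_factor` is hypothetical and not part of this codebase.
def kl_annealing_factor(epoch, config):
    """Linearly ramp the KL weight between anneal_start and anneal_end."""
    start, end = config["anneal_start"], config["anneal_end"]
    if end <= start:  # annealing disabled, e.g. anneal_start = anneal_end = 0
        return config["kl_weight"]
    ramp = min(max((epoch - start) / (end - start), 0.0), 1.0)
    return config["kl_weight"] * ramp

Under this reading, the values committed here make the ramp degenerate, so the effective KL weight is the constant 1e-4 from the first epoch onward.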

mp_transformer/models/decoder.py (+6 −3)

@@ -40,10 +40,13 @@ def __init__(self, config):
         # Decodes latent primtives and timestamps into subsequences of poses
         self.decoder = nn.Sequential(
             # self.feat_time: time feature dimension
-            nn.Linear(self.latent_dim + 2 * self.feat_time, self.hidden_dim),
+            # nn.Linear(self.latent_dim + 2 * self.feat_time, self.hidden_dim),
+            nn.Linear(self.latent_dim + 2 * self.feat_time, self.hidden_dim // 2),
             nn.ReLU(),
-            nn.LayerNorm(self.hidden_dim),
-            nn.Linear(self.hidden_dim, self.hidden_dim),
+            # nn.LayerNorm(self.hidden_dim),
+            # nn.Linear(self.hidden_dim, self.hidden_dim),
+            nn.LayerNorm(self.hidden_dim // 2),
+            nn.Linear(self.hidden_dim // 2, self.hidden_dim),
             nn.ReLU(),
             nn.LayerNorm(self.hidden_dim),
             nn.Linear(self.hidden_dim, self.pose_dim),
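The decoder MLP now bottlenecks its first layer to hidden_dim // 2 before widening back to hidden_dim. A standalone shape sketch of the resulting stack, assuming the config values committed above (latent_dim=12, hidden_dim=128, pose_dim=6); feat_time does not appear in this diff, so the value 8 below is only a placeholder.

import torch
from torch import nn

# Shape sketch of the bottlenecked decoder MLP after this commit.
# Dimensions follow config.py above; feat_time=8 is a placeholder.
latent_dim, hidden_dim, pose_dim, feat_time = 12, 128, 6, 8

decoder = nn.Sequential(
    nn.Linear(latent_dim + 2 * feat_time, hidden_dim // 2),  # 28 -> 64
    nn.ReLU(),
    nn.LayerNorm(hidden_dim // 2),
    nn.Linear(hidden_dim // 2, hidden_dim),                  # 64 -> 128
    nn.ReLU(),
    nn.LayerNorm(hidden_dim),
    nn.Linear(hidden_dim, pose_dim),                         # 128 -> 6
)

x = torch.randn(4, latent_dim + 2 * feat_time)
print(decoder(x).shape)  # torch.Size([4, 6])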

mp_transformer/models/encoder.py (+26 −9)

@@ -12,7 +12,8 @@ class PositionalEncodingLayer(pl.LightningModule):
 
     def __init__(self, config):
         super().__init__()
-        self.latent_dim = config["latent_dim"]
+        # self.latent_dim = config["latent_dim"]
+        self.latent_dim = config["hidden_dim"]
 
         # sinusoidal frequencies for positional encoding
         # linearly spaced 1D tensor ranging from 0 to self.latent_dim a size of self.latent_dim // 2
@@ -74,6 +75,7 @@ def __init__(self, config):
         # self.save_hyperparameters(config) # PyTorch Lightning
 
         self.pose_dim = config["pose_dim"]
+        self.hidden_dim = config["hidden_dim"]
         self.num_primitives = config["num_primitives"]
         self.latent_dim = config["latent_dim"]
         self.num_attention_heads = config["num_attention_heads"]
@@ -82,12 +84,20 @@ def __init__(self, config):
 
         self.positional_encoding = PositionalEncodingLayer(config)
 
+        # encoder_layer = nn.TransformerEncoderLayer(
+        #     d_model=self.latent_dim,
+        #     nhead=self.num_attention_heads,
+        # )
+        # decoder_layer = nn.TransformerDecoderLayer(
+        #     d_model=self.latent_dim,
+        #     nhead=self.num_attention_heads,
+        # )
         encoder_layer = nn.TransformerEncoderLayer(
-            d_model=self.latent_dim,
+            d_model=self.hidden_dim,
             nhead=self.num_attention_heads,
         )
         decoder_layer = nn.TransformerDecoderLayer(
-            d_model=self.latent_dim,
+            d_model=self.hidden_dim,
             nhead=self.num_attention_heads,
         )
         self.encoder_segments = torch.nn.TransformerEncoder(
@@ -97,23 +107,30 @@ def __init__(self, config):
             decoder_layer=decoder_layer, num_layers=self.num_transformer_layers
         )
 
-        self.embedding = nn.Linear(self.pose_dim, self.latent_dim)
+        self.embedding = nn.Linear(self.pose_dim, self.hidden_dim)
 
         self.mean_encoder = nn.Sequential(
-            nn.Linear(self.latent_dim, self.latent_dim),
+            # nn.Linear(self.hidden_dim, self.hidden_dim),
+            nn.Linear(self.hidden_dim, self.hidden_dim // 2),
             nn.ReLU(),
-            nn.Linear(self.latent_dim, self.latent_dim),
+            # nn.Linear(self.hidden_dim, self.latent_dim),
+            nn.Linear(self.hidden_dim // 2, self.latent_dim),
         )
         self.logvar_encoder = nn.Sequential(
-            nn.Linear(self.latent_dim, self.latent_dim),
+            # nn.Linear(self.hidden_dim, self.hidden_dim),
+            nn.Linear(self.hidden_dim, self.hidden_dim // 2),
             nn.ReLU(),
-            nn.Linear(self.latent_dim, self.latent_dim),
+            # nn.Linear(self.hidden_dim, self.latent_dim),
+            nn.Linear(self.hidden_dim // 2, self.latent_dim),
        )
 
         # positional encoding used as input for the transformer decoder
         # TODO: keep or move?
+        # self.initial_encoding = self.get_positional_encoding(
+        #     self.num_primitives, self.latent_dim
+        # )
         self.initial_encoding = self.get_positional_encoding(
-            self.num_primitives, self.latent_dim
+            self.num_primitives, self.hidden_dim
         )
 
         # TODO: use PositionalEncodingLayer instead?
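After this change the transformer layers, the pose embedding, and the initial positional encoding all operate at hidden_dim, and only the mean/logvar heads compress down to latent_dim. Below is a shape-flow sketch under the committed config (pose_dim=6, hidden_dim=128, latent_dim=12, 4 heads, 4 layers); it illustrates the dimension changes only and is not the repository's encoder class.

import torch
from torch import nn

# Shape-flow sketch of the encoder after this commit (values from config.py).
pose_dim, hidden_dim, latent_dim, nhead, nlayers = 6, 128, 12, 4, 4

embedding = nn.Linear(pose_dim, hidden_dim)          # pose -> transformer width
enc_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead)
encoder = nn.TransformerEncoder(enc_layer, num_layers=nlayers)
mean_head = nn.Sequential(                           # bottleneck down to the latent
    nn.Linear(hidden_dim, hidden_dim // 2),
    nn.ReLU(),
    nn.Linear(hidden_dim // 2, latent_dim),
)

poses = torch.randn(128, 8, pose_dim)   # (sequence, batch, pose_dim)
h = encoder(embedding(poses))           # (128, 8, 128): attention runs at hidden_dim
mu = mean_head(h)                       # (128, 8, 12): only the VAE heads see latent_dim
print(h.shape, mu.shape)

Note that nn.TransformerEncoderLayer requires d_model to be divisible by nhead, which 128 and 4 satisfy; before this commit the constraint applied to latent_dim instead.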

unittests/test_train.py (+1 −1)

@@ -13,7 +13,7 @@ def test_main(self):
         minimal_config = {
             "latent_dim": 4,
             "num_primitives": 2,
-            "hidden_dim": 2,
+            "hidden_dim": 4,
             "batch_size": 3,
             "sequence_length": 5,
             "N_train": 4,
