Skip to content

Commit ac52ddf

Browse files
committed
fix MPJPE loss
1 parent 26903fb commit ac52ddf

10 files changed

+632
-310
lines changed

demo/demo.ipynb

+578-27
Large diffs are not rendered by default.

demo/load_from_wandb.ipynb

+21-256
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 1,
5+
"execution_count": null,
66
"metadata": {},
77
"outputs": [],
88
"source": [
@@ -17,17 +17,9 @@
1717
},
1818
{
1919
"cell_type": "code",
20-
"execution_count": 2,
20+
"execution_count": null,
2121
"metadata": {},
22-
"outputs": [
23-
{
24-
"name": "stdout",
25-
"output_type": "stream",
26-
"text": [
27-
"/data/daniel/git/mp-transformer\n"
28-
]
29-
}
30-
],
22+
"outputs": [],
3123
"source": [
3224
"current_dir = Path.cwd().parts[-1]\n",
3325
"if current_dir == \"demo\":\n",
@@ -37,98 +29,9 @@
3729
},
3830
{
3931
"cell_type": "code",
40-
"execution_count": 3,
32+
"execution_count": null,
4133
"metadata": {},
42-
"outputs": [
43-
{
44-
"name": "stderr",
45-
"output_type": "stream",
46-
"text": [
47-
"Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n",
48-
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mdaniel-a\u001b[0m (\u001b[33mtcs-mr\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
49-
]
50-
},
51-
{
52-
"data": {
53-
"text/html": [
54-
"wandb version 0.15.10 is available! To upgrade, please run:\n",
55-
" $ pip install wandb --upgrade"
56-
],
57-
"text/plain": [
58-
"<IPython.core.display.HTML object>"
59-
]
60-
},
61-
"metadata": {},
62-
"output_type": "display_data"
63-
},
64-
{
65-
"data": {
66-
"text/html": [
67-
"Tracking run with wandb version 0.15.4"
68-
],
69-
"text/plain": [
70-
"<IPython.core.display.HTML object>"
71-
]
72-
},
73-
"metadata": {},
74-
"output_type": "display_data"
75-
},
76-
{
77-
"data": {
78-
"text/html": [
79-
"Run data is saved locally in <code>/data/daniel/git/mp-transformer/wandb/run-20230914_181329-xlh6vk92</code>"
80-
],
81-
"text/plain": [
82-
"<IPython.core.display.HTML object>"
83-
]
84-
},
85-
"metadata": {},
86-
"output_type": "display_data"
87-
},
88-
{
89-
"data": {
90-
"text/html": [
91-
"Syncing run <strong><a href='https://wandb.ai/tcs-mr/mp-transformer/runs/xlh6vk92' target=\"_blank\">peach-cloud-545</a></strong> to <a href='https://wandb.ai/tcs-mr/mp-transformer' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
92-
],
93-
"text/plain": [
94-
"<IPython.core.display.HTML object>"
95-
]
96-
},
97-
"metadata": {},
98-
"output_type": "display_data"
99-
},
100-
{
101-
"data": {
102-
"text/html": [
103-
" View project at <a href='https://wandb.ai/tcs-mr/mp-transformer' target=\"_blank\">https://wandb.ai/tcs-mr/mp-transformer</a>"
104-
],
105-
"text/plain": [
106-
"<IPython.core.display.HTML object>"
107-
]
108-
},
109-
"metadata": {},
110-
"output_type": "display_data"
111-
},
112-
{
113-
"data": {
114-
"text/html": [
115-
" View run at <a href='https://wandb.ai/tcs-mr/mp-transformer/runs/xlh6vk92' target=\"_blank\">https://wandb.ai/tcs-mr/mp-transformer/runs/xlh6vk92</a>"
116-
],
117-
"text/plain": [
118-
"<IPython.core.display.HTML object>"
119-
]
120-
},
121-
"metadata": {},
122-
"output_type": "display_data"
123-
},
124-
{
125-
"name": "stderr",
126-
"output_type": "stream",
127-
"text": [
128-
"\u001b[34m\u001b[1mwandb\u001b[0m: 1 of 1 files downloaded. \n"
129-
]
130-
}
131-
],
34+
"outputs": [],
13235
"source": [
13336
"run = wandb.init(project=\"mp-transformer\")\n",
13437
"artifact = run.use_artifact(\"tcs-mr/mp-transformer/model:v300\", type='model')\n",
@@ -137,63 +40,18 @@
13740
},
13841
{
13942
"cell_type": "code",
140-
"execution_count": 4,
43+
"execution_count": null,
14144
"metadata": {},
142-
"outputs": [
143-
{
144-
"name": "stdout",
145-
"output_type": "stream",
146-
"text": [
147-
"./artifacts/model:v300\n"
148-
]
149-
}
150-
],
45+
"outputs": [],
15146
"source": [
15247
"print(artifact_dir)"
15348
]
15449
},
15550
{
15651
"cell_type": "code",
157-
"execution_count": 7,
52+
"execution_count": null,
15853
"metadata": {},
159-
"outputs": [
160-
{
161-
"data": {
162-
"text/html": [
163-
"Waiting for W&B process to finish... <strong style=\"color:green\">(success).</strong>"
164-
],
165-
"text/plain": [
166-
"<IPython.core.display.HTML object>"
167-
]
168-
},
169-
"metadata": {},
170-
"output_type": "display_data"
171-
},
172-
{
173-
"data": {
174-
"text/html": [
175-
" View run <strong style=\"color:#cdcd00\">peach-cloud-545</strong> at: <a href='https://wandb.ai/tcs-mr/mp-transformer/runs/xlh6vk92' target=\"_blank\">https://wandb.ai/tcs-mr/mp-transformer/runs/xlh6vk92</a><br/>Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
176-
],
177-
"text/plain": [
178-
"<IPython.core.display.HTML object>"
179-
]
180-
},
181-
"metadata": {},
182-
"output_type": "display_data"
183-
},
184-
{
185-
"data": {
186-
"text/html": [
187-
"Find logs at: <code>./wandb/run-20230914_181329-xlh6vk92/logs</code>"
188-
],
189-
"text/plain": [
190-
"<IPython.core.display.HTML object>"
191-
]
192-
},
193-
"metadata": {},
194-
"output_type": "display_data"
195-
}
196-
],
54+
"outputs": [],
19755
"source": [
19856
"CONFIG[\"hidden_dim\"] = 40\n",
19957
"CONFIG[\"latent_dim\"] = 48\n",
@@ -208,17 +66,9 @@
20866
},
20967
{
21068
"cell_type": "code",
211-
"execution_count": 6,
69+
"execution_count": null,
21270
"metadata": {},
213-
"outputs": [
214-
{
215-
"name": "stdout",
216-
"output_type": "stream",
217-
"text": [
218-
"Video saved to tmp/comp_vid.mp4\n"
219-
]
220-
}
221-
],
71+
"outputs": [],
22272
"source": [
22373
"item = val_dataset[-1]\n",
22474
"# item = val_dataset[64]\n",
@@ -229,26 +79,9 @@
22979
},
23080
{
23181
"cell_type": "code",
232-
"execution_count": 7,
82+
"execution_count": null,
23383
"metadata": {},
234-
"outputs": [
235-
{
236-
"data": {
237-
"text/html": [
238-
"\n",
239-
"<video width=\"320\" height=\"240\" controls>\n",
240-
" <source src=\"../tmp/comp_vid.mp4\" type=\"video/mp4\">\n",
241-
"</video>\n"
242-
],
243-
"text/plain": [
244-
"<IPython.core.display.HTML object>"
245-
]
246-
},
247-
"execution_count": 7,
248-
"metadata": {},
249-
"output_type": "execute_result"
250-
}
251-
],
84+
"outputs": [],
25285
"source": [
25386
"\n",
25487
"HTML(\"\"\"\n",
@@ -260,53 +93,19 @@
26093
},
26194
{
26295
"cell_type": "code",
263-
"execution_count": 8,
96+
"execution_count": null,
26497
"metadata": {},
265-
"outputs": [
266-
{
267-
"name": "stdout",
268-
"output_type": "stream",
269-
"text": [
270-
"poses range: [3.5189839309168747e-06, 0.9999984502792358]\n",
271-
"mus range: [-1.7213106155395508, 1.4266711473464966]\n",
272-
"average mu: -0.047322604805231094\n",
273-
"logvars range: [-8.529964447021484, -5.8591179847717285]\n",
274-
"median logvar: -7.437717914581299\n",
275-
"gt_latents range: [-1.848239541053772, 2.237262487411499]\n",
276-
"average gt_latents: 0.24795043468475342\n",
277-
"random_latents range: [-1.7329806089401245, 1.4522238969802856]\n",
278-
"average random_latents: -0.05122515186667442\n",
279-
"Video saved to tmp/fill_vid.mp4\n"
280-
]
281-
}
282-
],
98+
"outputs": [],
28399
"source": [
284100
"item = val_dataset[50]\n",
285101
"save_side_by_side_video(item, model, from_idx=1, to_idx=4, path=\"tmp/fill_vid.mp4\")"
286102
]
287103
},
288104
{
289105
"cell_type": "code",
290-
"execution_count": 9,
106+
"execution_count": null,
291107
"metadata": {},
292-
"outputs": [
293-
{
294-
"data": {
295-
"text/html": [
296-
"\n",
297-
"<video width=\"320\" height=\"240\" controls>\n",
298-
" <source src=\"../tmp/fill_vid.mp4\" type=\"video/mp4\">\n",
299-
"</video>\n"
300-
],
301-
"text/plain": [
302-
"<IPython.core.display.HTML object>"
303-
]
304-
},
305-
"execution_count": 9,
306-
"metadata": {},
307-
"output_type": "execute_result"
308-
}
309-
],
108+
"outputs": [],
310109
"source": [
311110
"\n",
312111
"\n",
@@ -319,52 +118,18 @@
319118
},
320119
{
321120
"cell_type": "code",
322-
"execution_count": 10,
121+
"execution_count": null,
323122
"metadata": {},
324-
"outputs": [
325-
{
326-
"name": "stdout",
327-
"output_type": "stream",
328-
"text": [
329-
"poses range: [3.5189839309168747e-06, 0.9999984502792358]\n",
330-
"mus range: [-1.487624168395996, 1.8445534706115723]\n",
331-
"average mu: 0.04888433218002319\n",
332-
"logvars range: [-8.518484115600586, -5.8925089836120605]\n",
333-
"median logvar: -7.370262622833252\n",
334-
"gt_latents range: [-1.948813796043396, 2.2732694149017334]\n",
335-
"average gt_latents: 0.24087952077388763\n",
336-
"random_latents range: [-1.4572112560272217, 1.8878892660140991]\n",
337-
"average random_latents: 0.04571348428726196\n",
338-
"Video saved to tmp/gen_vid.mp4\n"
339-
]
340-
}
341-
],
123+
"outputs": [],
342124
"source": [
343125
"save_side_by_side_video(item, model, from_idx=0, path=\"tmp/gen_vid.mp4\")"
344126
]
345127
},
346128
{
347129
"cell_type": "code",
348-
"execution_count": 11,
130+
"execution_count": null,
349131
"metadata": {},
350-
"outputs": [
351-
{
352-
"data": {
353-
"text/html": [
354-
"\n",
355-
"<video width=\"320\" height=\"240\" controls>\n",
356-
" <source src=\"../tmp/gen_vid.mp4\" type=\"video/mp4\">\n",
357-
"</video>\n"
358-
],
359-
"text/plain": [
360-
"<IPython.core.display.HTML object>"
361-
]
362-
},
363-
"execution_count": 11,
364-
"metadata": {},
365-
"output_type": "execute_result"
366-
}
367-
],
132+
"outputs": [],
368133
"source": [
369134
"\n",
370135
"\n",
2.34 KB
Loading
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"model": [0.12564842402935028, 0.08892914652824402, 0.0831979289650917, 0.08228670805692673, 0.08564914762973785, 0.0953526571393013, 0.1093117892742157, 0.13006433844566345, 0.15024712681770325, 0.17235219478607178, 0.19508112967014313, 0.21641355752944946, 0.23918820917606354, 0.2606200873851776, 0.27891215682029724, 0.29907020926475525, 0.3144414722919464, 0.33412039279937744, 0.3491173982620239, 0.36269068717956543], "model_single_layer_transformer": [0.13918371498584747, 0.09241282939910889, 0.08359085023403168, 0.08277571201324463, 0.08699499070644379, 0.09778112918138504, 0.11320507526397705, 0.13444054126739502, 0.155262291431427, 0.1772998869419098, 0.20199641585350037, 0.22535252571105957, 0.24829532206058502, 0.26870790123939514, 0.28866785764694214, 0.3094658851623535, 0.3262939751148224, 0.3449738025665283, 0.3587801158428192, 0.3760082721710205], "model_vae": [0.30333247780799866, 0.3200998306274414, 0.33794063329696655, 0.3640925884246826, 0.3904769718647003, 0.4185958206653595, 0.44002050161361694, 0.46187642216682434, 0.4782748222351074, 0.49203750491142273, 0.5079672932624817, 0.519733726978302, 0.5377718806266785, 0.541067361831665, 0.5547375679016113, 0.5591103434562683, 0.5664754509925842, 0.5756450891494751, 0.5771839618682861, 0.5854555368423462], "model_vae_single_layer_transformer": [0.3151271939277649, 0.31420376896858215, 0.3264966309070587, 0.35072073340415955, 0.3787328600883484, 0.40926486253738403, 0.4301280379295349, 0.4536152184009552, 0.4696636199951172, 0.4861705005168915, 0.5050753355026245, 0.5159381031990051, 0.5327478051185608, 0.5394688844680786, 0.5467861890792847, 0.5526284575462341, 0.5601837635040283, 0.5701719522476196, 0.5759778022766113, 0.5826566815376282]}
Loading

demo/mpjpe_comparison.json

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"model": [21.98292350769043, 14.340300559997559, 13.43407917022705, 13.281380653381348, 13.68807315826416, 14.619558334350586, 16.02347755432129, 18.08142852783203, 20.648395538330078, 23.337923049926758, 26.42051887512207, 29.621652603149414, 33.34819412231445, 37.707115173339844, 41.067047119140625, 45.00002670288086, 48.33279800415039, 52.669677734375, 55.95282745361328, 58.46841812133789], "model_single_layer_transformer": [21.327821731567383, 16.639490127563477, 14.898528099060059, 14.581068992614746, 14.713642120361328, 15.887846946716309, 17.2874813079834, 19.435869216918945, 22.553186416625977, 25.458030700683594, 29.034015655517578, 32.70307540893555, 36.36600112915039, 40.38428497314453, 45.23580551147461, 48.895408630371094, 52.23849868774414, 55.964927673339844, 59.8380012512207, 62.868343353271484], "model_vae": [39.20149230957031, 52.6158332824707, 55.64812088012695, 59.33192443847656, 63.002384185791016, 68.69374084472656, 72.3512954711914, 76.72998046875, 80.79139709472656, 85.02607727050781, 86.35516357421875, 91.21097564697266, 94.64981842041016, 95.75718688964844, 99.6708755493164, 100.07467651367188, 102.95051574707031, 105.49016571044922, 103.9454345703125, 105.8687973022461], "model_vae_single_layer_transformer": [53.5693359375, 50.3300666809082, 51.67551803588867, 56.25404739379883, 61.61893844604492, 67.36402893066406, 72.91018676757812, 77.97484588623047, 82.93614196777344, 87.12987518310547, 90.387939453125, 94.40081024169922, 97.01863861083984, 99.4786605834961, 102.1368637084961, 102.08549499511719, 104.48709106445312, 106.8950424194336, 107.47530364990234, 109.40076446533203]}

demo/mpjpe_comparison.png

55.4 KB
Loading

mp_transformer/datasets/toy_dataset.py

-6
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,4 @@ def _get_segment(self, idx):
152152
images = torch.stack(images)
153153
ret["images"] = images
154154

155-
# Check for sudden jumps in the values of the poses tensor
156-
diff = poses[1:] - poses[:-1]
157-
max_jump = 0.25
158-
if torch.any(torch.abs(diff) > max_jump):
159-
print(f"{diff.abs().max()=} at {idx=}")
160-
161155
return ret

0 commit comments

Comments
 (0)