Skip to content

Commit 3c559f6

Browse files
author
Sehoon Kim
committed
Minor fix
1 parent 076e7a9 commit 3c559f6

File tree

3 files changed

+11
-9
lines changed

3 files changed

+11
-9
lines changed

evaluate.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
import os
2-
import sys
3-
41
import torch
52
from torch.utils.data import DataLoader
63

@@ -44,10 +41,11 @@ def evaluate():
4441

4542
gts = []
4643
preds = []
44+
regularity_score = []
4745

4846
with torch.set_grad_enabled(False):
4947
pbar = tqdm(testloader)
50-
regularity_score = []
48+
5149
for i, seqs in enumerate(pbar):
5250
model.eval()
5351

@@ -65,12 +63,17 @@ def evaluate():
6563
gts.append(seqs)
6664
preds.append(outs)
6765

68-
seqs_reconstruction_cost = np.array([np.linalg.norm(np.subtract(gts[i],preds[i])) for i in range(0,len(pbar))])
69-
sa = (seqs_reconstruction_cost - np.min(seqs_reconstruction_cost)) / np.max(seqs_reconstruction_cost)
70-
sr = 1 - sa
66+
seqs_reconstruction_cost = np.array([np.linalg.norm(np.subtract(gts[j],preds[j])) for j in range(0,i+1)])
67+
sa = (seqs_reconstruction_cost - np.min(seqs_reconstruction_cost)) / np.max(seqs_reconstruction_cost)
68+
sr = 1 - sa
69+
70+
if i == 0:
71+
regularity_score.extend(sr)
72+
else:
73+
regularity_score.append(sr[-1])
7174

7275
f = open('result.csv','w')
73-
for i, score in enumerate(sr):
76+
for i, score in enumerate(regularity_score):
7477
vstr = str(i) + ',' + str(score) + '\n'
7578
f.write(vstr)
7679
f.close()

results/Test001.png

-895 Bytes
Loading

train.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import os
2-
import sys
32

43
import torch
54
from torch.utils.data import DataLoader

0 commit comments

Comments
 (0)