File tree Expand file tree Collapse file tree 3 files changed +11
-9
lines changed Expand file tree Collapse file tree 3 files changed +11
-9
lines changed Original file line number Diff line number Diff line change 1
- import os
2
- import sys
3
-
4
1
import torch
5
2
from torch .utils .data import DataLoader
6
3
@@ -44,10 +41,11 @@ def evaluate():
44
41
45
42
gts = []
46
43
preds = []
44
+ regularity_score = []
47
45
48
46
with torch .set_grad_enabled (False ):
49
47
pbar = tqdm (testloader )
50
- regularity_score = []
48
+
51
49
for i , seqs in enumerate (pbar ):
52
50
model .eval ()
53
51
@@ -65,12 +63,17 @@ def evaluate():
65
63
gts .append (seqs )
66
64
preds .append (outs )
67
65
68
- seqs_reconstruction_cost = np .array ([np .linalg .norm (np .subtract (gts [i ],preds [i ])) for i in range (0 ,len (pbar ))])
69
- sa = (seqs_reconstruction_cost - np .min (seqs_reconstruction_cost )) / np .max (seqs_reconstruction_cost )
70
- sr = 1 - sa
66
+ seqs_reconstruction_cost = np .array ([np .linalg .norm (np .subtract (gts [j ],preds [j ])) for j in range (0 ,i + 1 )])
67
+ sa = (seqs_reconstruction_cost - np .min (seqs_reconstruction_cost )) / np .max (seqs_reconstruction_cost )
68
+ sr = 1 - sa
69
+
70
+ if i == 0 :
71
+ regularity_score .extend (sr )
72
+ else :
73
+ regularity_score .append (sr [- 1 ])
71
74
72
75
f = open ('result.csv' ,'w' )
73
- for i , score in enumerate (sr ):
76
+ for i , score in enumerate (regularity_score ):
74
77
vstr = str (i ) + ',' + str (score ) + '\n '
75
78
f .write (vstr )
76
79
f .close ()
Original file line number Diff line number Diff line change 1
1
import os
2
- import sys
3
2
4
3
import torch
5
4
from torch .utils .data import DataLoader
You can’t perform that action at this time.
0 commit comments