import os

def create_train_dataset(batch_size=128, root='../data'):
+    """
+    Create the training dataset.
+    """
+
    transform_train = transforms.Compose([
        transforms.ToTensor(),
    ])
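# A quick sketch of what the transform above yields -- ToTensor converts an
# HWC uint8 array (or PIL image) into a CHW float tensor scaled to [0, 1];
# the fake image below is only for illustration:
#
#     import numpy as np
#     from torchvision import transforms
#
#     img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)  # fake 32x32 RGB image
#     t = transforms.ToTensor()(img)
#     print(t.shape, t.dtype)  # torch.Size([3, 32, 32]) torch.float32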
@@ -53,9 +57,10 @@ def load_checkpoint(file_name, net=None, optimizer=None, lr_scheduler=None
    print("=> no checkpoint found at '{}'".format(file_name))

def make_symlink(source, link_name):
-    '''
+    """
    Note: overwriting enabled!
-    '''
+    """
+
    if os.path.exists(link_name):
        print("Link name already exists! Removing '{}' and overwriting".format(link_name))
        os.remove(link_name)
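# Usage sketch for make_symlink, with hypothetical paths -- any existing link
# at link_name is removed first, so a "latest" link can be repointed each epoch:
#
#     make_symlink('checkpoints/epoch_10.pth', 'checkpoints/latest.pth')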
@@ -87,3 +92,41 @@ def onehot_like(a, index, value=1):
    x = np.zeros_like(a)
    x[index] = value
    return x
+
+def reduce_sum(x, keepdim=True):
+    # silly PyTorch, when will you get proper reducing sums/means?
+    # (sums over every non-batch dim, collapsing from the last dim inward)
+    for a in reversed(range(1, x.dim())):
+        x = x.sum(a, keepdim=keepdim)
+    return x
+
+def arctanh(x, eps=1e-6):
+    """
+    Calculate arctanh(x).
+    """
+    x = x * (1. - eps)  # shrink toward 0 so |x| = 1 stays finite; also avoids mutating the input in place
+    return (np.log((1 + x) / (1 - x))) * 0.5
+
+def l2r_dist(x, y, keepdim=True, eps=1e-8):
+    d = (x - y)**2
+    d = reduce_sum(d, keepdim=keepdim)
+    d += eps  # to prevent infinite gradient at 0
+    return d.sqrt()
+
+
+def l2_dist(x, y, keepdim=True):
+    # note: returns the squared L2 distance (no sqrt)
+    d = (x - y)**2
+    return reduce_sum(d, keepdim=keepdim)
+
+
+def l1_dist(x, y, keepdim=True):
+    d = torch.abs(x - y)
+    return reduce_sum(d, keepdim=keepdim)
+
+
+def l2_norm(x, keepdim=True):
+    norm = reduce_sum(x * x, keepdim=keepdim)
+    return norm.sqrt()
+
+
+def l1_norm(x, keepdim=True):
+    return reduce_sum(x.abs(), keepdim=keepdim)
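# Quick sanity check for the helpers above -- a sketch assuming torch and
# numpy are imported near the top of the module alongside os:
#
#     x = torch.randn(4, 3, 8, 8)   # a batch of 4 CHW "images"
#     y = torch.randn(4, 3, 8, 8)
#     print(reduce_sum(x).shape)    # torch.Size([4, 1, 1, 1]); all non-batch dims reduced
#     print(l2_dist(x, y).shape)    # squared L2 distance per example
#     print(l2r_dist(x, y).shape)   # rooted (true) L2 distance per example
#     print(l1_norm(x).shape)       # torch.Size([4, 1, 1, 1])
#
#     v = np.linspace(-0.999, 0.999, 5)
#     print(np.tanh(arctanh(v)))    # round-trips to ~v, up to the eps shrink
#
# On newer PyTorch versions, reduce_sum(x) can be written directly as
# x.sum(dim=tuple(range(1, x.dim())), keepdim=True).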