Skip to content

Commit 834cecd

Browse files
committed
add rgb std for datatransform
1 parent 19d4386 commit 834cecd

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

data/data_augment.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,14 +228,16 @@ class BaseTransform(object):
228228
resize (int): input dimension to SSD
229229
rgb_means ((int,int,int)): average RGB of the dataset
230230
(104,117,123)
231+
rgb_std: std of the dataset
231232
swap ((int,int,int)): final order of channels
232233
Returns:
233234
transform (transform) : callable transform to be applied to test/val
234235
data
235236
"""
236-
def __init__(self, resize, rgb_means, swap=(2, 0, 1)):
237+
def __init__(self, resize, rgb_means,rgb_std = (1,1,1), swap=(2, 0, 1)):
237238
self.means = rgb_means
238239
self.resize = resize
240+
self.std = rgb_std
239241
self.swap = swap
240242

241243
# assume input is cv2 img for now
@@ -246,5 +248,6 @@ def __call__(self, img):
246248
img = cv2.resize(np.array(img), (self.resize,
247249
self.resize),interpolation = interp_method).astype(np.float32)
248250
img -= self.means
251+
img /= self.std
249252
img = img.transpose(self.swap)
250253
return torch.from_numpy(img)

0 commit comments

Comments
 (0)