transforms.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import warnings

from torchvision import transforms
from torchvision.transforms._transforms_video import (
    CenterCropVideo, NormalizeVideo, RandomCropVideo,
    RandomHorizontalFlipVideo, RandomResizedCropVideo, ToTensorVideo)

warnings.filterwarnings("ignore")


def init_transform_dict(input_res=224,
                        center_crop=256,
                        randcrop_scale=(0.5, 1.0),
                        color_jitter=(0, 0, 0),
                        norm_mean=(0.485, 0.456, 0.406),
                        norm_std=(0.229, 0.224, 0.225)):
    """Return frame-level (image) transforms for the 'train', 'val' and 'test' splits."""
    normalize = transforms.Normalize(mean=norm_mean, std=norm_std)
    tsfm_dict = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(input_res, scale=randcrop_scale),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=color_jitter[0], saturation=color_jitter[1], hue=color_jitter[2]),
            normalize,
        ]),
        'val': transforms.Compose([
            transforms.Resize(center_crop),
            transforms.CenterCrop(center_crop),
            transforms.Resize(input_res),
            normalize,
        ]),
        'test': transforms.Compose([
            transforms.Resize(center_crop),
            transforms.CenterCrop(center_crop),
            transforms.Resize(input_res),
            normalize,
        ])
    }
    return tsfm_dict


def init_video_transform_dict(input_res=224,
                              center_crop=256,
                              randcrop_scale=(0.5, 1.0),
                              color_jitter=(0, 0, 0),
                              norm_mean=(0.485, 0.456, 0.406),
                              norm_std=(0.229, 0.224, 0.225)):
    """Return clip-level (video tensor) transforms for the 'train', 'val' and 'test' splits."""
    print('Video Transform is used!')
    normalize = NormalizeVideo(mean=norm_mean, std=norm_std)
    tsfm_dict = {
        'train': transforms.Compose([
            RandomResizedCropVideo(input_res, scale=randcrop_scale),
            RandomHorizontalFlipVideo(),
            transforms.ColorJitter(brightness=color_jitter[0], saturation=color_jitter[1], hue=color_jitter[2]),
            normalize,
        ]),
        'val': transforms.Compose([
            transforms.Resize(center_crop),
            transforms.CenterCrop(center_crop),
            transforms.Resize(input_res),
            normalize,
        ]),
        'test': transforms.Compose([
            transforms.Resize(center_crop),
            transforms.CenterCrop(center_crop),
            transforms.Resize(input_res),
            normalize,
        ])
    }
    return tsfm_dict
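

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). A minimal demo
# of how the returned dictionary might be applied, assuming clips follow the
# (C, T, H, W) float layout used by _transforms_video and a torchvision
# version whose Resize/CenterCrop accept batched tensors. The clip shape
# below is a placeholder chosen for the demo, not a value from the source.
if __name__ == '__main__':
    import torch

    video_tsfm = init_video_transform_dict(input_res=224)
    clip = torch.rand(3, 16, 256, 320)  # dummy clip: C x T x H x W in [0, 1]
    out = video_tsfm['test'](clip)      # resize -> centre crop -> resize -> normalise
    print(out.shape)                    # e.g. torch.Size([3, 16, 224, 224])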