Skip to content

Commit 5cc8777

Browse files
Added input normalization
1 parent 328f3c6 commit 5cc8777

File tree

3 files changed

+27
-5
lines changed

3 files changed

+27
-5
lines changed

car_dataset.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
from warnings import warn
55

66
import numpy as np
7+
import torch
78
import torch.utils.data as data
89
from PIL import Image
910
from sklearn.model_selection import train_test_split
10-
from torch.utils.data import Subset, Dataset
11+
from torch.utils.data import Subset, Dataset, DataLoader
1112

1213

1314
def train_val_datasets(dataset: Dataset, val_split: float = 0.5, shuffle: bool = True) -> Tuple[Dataset, Dataset]:
@@ -234,3 +235,23 @@ def statistics(self) -> str:
234235
fmt_str += "Avg Height: {}\n".format(sum([img.height for img, gt in self]) / float(len(self)))
235236
fmt_str += "Avg Aspect: {}\n".format(sum([img.width / img.height for img, gt in self]) / float(len(self)))
236237
return fmt_str
238+
239+
def mean_and_std(self) -> Tuple[float, float]:
240+
loader = DataLoader(
241+
self.subsets['train'],
242+
batch_size=10,
243+
num_workers=1,
244+
shuffle=False
245+
)
246+
mean = torch.full((3,), 0.0)
247+
std = torch.full((3,), 0.0)
248+
nb_samples = 0.
249+
for data, gt in loader:
250+
batch_samples = data.size(0)
251+
data = data.view(batch_samples, data.size(1), -1)
252+
mean += data.mean(2).sum(0)
253+
std += data.std(2).sum(0)
254+
nb_samples += batch_samples
255+
mean /= nb_samples
256+
std /= nb_samples
257+
return mean, std

config.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
data : "ORAND-CAR-2014/CAR-A/"
2-
train_val_split : 0.008
3-
epochs : 5
4-
batch_size : 8
2+
train_val_split: 0.8
3+
epochs: 1000
4+
batch_size: 16
55
lr : 1.0e-4
6-
verbose : True
6+
verbose: False
77
log: train_log.txt

train.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def create_dataloader(args: Namespace, verbose: bool = False) -> Dict[str, DataL
7777
'train': transforms.Compose([
7878
transforms.Resize((width, height)),
7979
transforms.ToTensor(),
80+
transforms.Normalize([0.6205, 0.6205, 0.6205], [0.1343, 0.1343, 0.1343])
8081
]),
8182
'test': transforms.Compose([
8283
transforms.Resize((width, height)),

0 commit comments

Comments
 (0)