diff --git a/configs/resa/resa18_tusimple.py b/configs/resa/resa18_tusimple.py index 2c937d3..781e6d3 100644 --- a/configs/resa/resa18_tusimple.py +++ b/configs/resa/resa18_tusimple.py @@ -11,9 +11,10 @@ ) featuremap_out_channel = 128 featuremap_out_stride = 8 -num_classes = 3 +num_classes = 7 num_lanes = 6 + 1 classification = True +autocast = True aggregator = dict( type='RESA', @@ -29,7 +30,7 @@ decoder=dict(type='BUSD'), thr=0.6, sample_y=sample_y, - cat_dim = (num_lanes - 1, num_classes) + cat_dim = (num_classes, num_lanes - 1) ) optimizer = dict( @@ -40,9 +41,10 @@ ) -epochs = 15 -batch_size = 4 -total_iter = (3216 // batch_size + 1) * epochs +epochs = 25 +batch_size = 16 +total_training_samples = 3626 +total_iter = (total_training_samples // batch_size + 1) * epochs import math scheduler = dict( type = 'LambdaLR', @@ -110,6 +112,5 @@ log_interval = 200 eval_ep = 1 save_ep = epochs -#test_json_file='data/tusimple/label_data_0531_small.json' -test_json_file='data/tusimple/label_data_0531.json' +test_json_file='data/tusimple/test_label.json' lr_update_by_epoch = False \ No newline at end of file diff --git a/configs/ufld/resnet18_tusimple.py b/configs/ufld/resnet18_tusimple.py index 0d1579d..620f5c9 100644 --- a/configs/ufld/resnet18_tusimple.py +++ b/configs/ufld/resnet18_tusimple.py @@ -13,14 +13,13 @@ griding_num = 100 num_lanes = 6 - - classification = True -num_classes = 3 +num_classes = 7 +autocast = True heads = dict(type='LaneCls', dim = (griding_num + 1, 56, num_lanes), - cat_dim =(num_lanes, num_classes)) + cat_dim =(num_classes, num_lanes)) trainer = dict( type='LaneCls' @@ -38,9 +37,10 @@ ) #optimizer = dict(type='Adam', lr= 0.025, weight_decay = 0.0001) # 3e-4 for batchsize 8 -epochs = 15 -batch_size = 4 -total_iter = (3216 // batch_size + 1) * epochs +epochs = 40 +batch_size = 16 +total_training_samples = 3626 +total_iter = (total_training_samples // batch_size + 1) * epochs import math @@ -117,6 +117,5 @@ eval_ep = 1 save_ep = epochs row_anchor='tusimple_row_anchor' -test_json_file='data/tusimple/label_data_0531.json' -#test_json_file='data/tusimple/label_data_0601_small.json' +test_json_file='data/tusimple/test_label.json' lr_update_by_epoch = False diff --git a/lanedet/datasets/tusimple.py b/lanedet/datasets/tusimple.py index b743da4..1cba339 100644 --- a/lanedet/datasets/tusimple.py +++ b/lanedet/datasets/tusimple.py @@ -17,12 +17,9 @@ from sklearn.metrics import confusion_matrix SPLIT_FILES = { - #'trainval': ['label_data_0313.json', 'label_data_0601.json', 'label_data_0531.json'], - #'trainval': ['label_data_0531_small.json'], - 'trainval': ['label_data_0313.json', 'label_data_0601.json'], - 'val': ['label_data_0531.json'], - 'test': ['label_data_0531.json'], - #'val': ['label_data_0601_small.json'] + 'trainval': ['label_data_0313.json', 'label_data_0601.json', 'label_data_0531.json'], + 'val': ['test_label.json'], + 'test': ['test_label.json'] } @@ -38,7 +35,7 @@ def load_annotations(self): self.logger.info('Loading TuSimple annotations...') self.data_infos = [] max_lanes = 0 - df = {0:0, 1:1, 2:1, 3:2, 4:2, 5:2, 6:1, 7:0} + df = {0:0, 1:1, 2:2, 3:3, 4:3, 5:4, 6:5, 7:6} for anno_file in self.anno_files: anno_file = osp.join(self.data_root, anno_file) with open(anno_file, 'r') as anno_obj: @@ -115,17 +112,31 @@ def accuracy_fn(self, y_true, y_pred): return acc def evaluate_classification(self, predictions, ground_truth): - score = F.softmax(predictions, dim=2) - y_pred = score.argmax(dim=2) + score = F.softmax(predictions, dim=1) + y_pred = score.argmax(dim=1) return self.accuracy_fn(ground_truth, y_pred) def plot_confusion_matrix(self, y_true, y_pred): cf_matrix = confusion_matrix(y_true, y_pred) - class_names = ('background','solid-yellow', 'solid-white', 'dashed', 'double-dashed','botts\'-dots', 'double-solid-yellow', 'unknown') - + class_names = ('background','solid-yellow', 'solid-white', 'dashed','botts\'-dots', 'unknown') + #class_names = ('background', 'solid', 'dashed') # Create pandas dataframe dataframe = pd.DataFrame(cf_matrix, index=class_names, columns=class_names) + total_number_of_instances = dataframe.sum(1)[1:].sum() + + + df = {0:0, 1:1, 2:1, 3:2, 4:2, 5:1, 6:1} + y_true = list(map(df.get,y_true)) + y_pred = list(map(df.get,y_pred)) + cf_matrix_2 = confusion_matrix(y_true, y_pred) + true_positives_2 = np.diag(cf_matrix_2)[1:].sum() + accuracy_2 = true_positives_2 / total_number_of_instances + print(f"Accuracy for 2 classes: {accuracy_2}") + + true_positives = np.diag(cf_matrix)[1:].sum() + accuracy = true_positives / total_number_of_instances + print(f"Accuracy for 6 classes: {accuracy}") # compute metrices from confusion matrix FP = cf_matrix.sum(axis=0) - np.diag(cf_matrix) diff --git a/lanedet/engine/runner.py b/lanedet/engine/runner.py index 0c81596..37aaab2 100644 --- a/lanedet/engine/runner.py +++ b/lanedet/engine/runner.py @@ -14,7 +14,8 @@ from lanedet.utils.recorder import build_recorder from lanedet.utils.net_utils import save_model, load_network from mmcv.parallel import MMDataParallel - +import torch.nn.functional as F +from torch.cuda.amp import autocast, GradScaler class Runner(object): def __init__(self, cfg): @@ -42,6 +43,7 @@ def __init__(self, cfg): self.classification_metric = 0. self.val_loader = None self.test_loader = None + self.scaler = GradScaler(enabled=True) def resume(self): if not self.cfg.load_from and not self.cfg.finetune_from: @@ -66,11 +68,22 @@ def train_epoch(self, epoch, train_loader): date_time = time.time() - end self.recorder.step += 1 data = self.to_cuda(data) - output = self.net(data) self.optimizer.zero_grad() - loss = output['loss'] - loss.backward() - self.optimizer.step() + + if self.cfg.autocast: + with autocast(enabled=True): + output = self.net(data) + loss = output['loss'] + + self.scaler.scale(loss).backward() + self.scaler.step(self.optimizer) + self.scaler.update() + else: + output = self.net(data) + loss = output['loss'] + loss.backward() + self.optimizer.step() + if not self.cfg.lr_update_by_epoch: self.scheduler.step() if self.warmup_scheduler: @@ -125,7 +138,6 @@ def validate(self): detection_metric = detection_out if detection_metric > self.detection_metric: self.detection_metric = detection_metric - self.save_ckpt(is_best=True) if self.cfg.classification: classification_acc /= len(self.val_loader) @@ -133,7 +145,7 @@ def validate(self): classification_metric = classification_acc if classification_metric > self.classification_metric: self.classification_metric = classification_metric - #self.save_ckpt(is_best=True) + self.save_ckpt(is_best=True) self.recorder.logger.info('Best detection metric: ' + str(self.detection_metric) + " " + 'Best classification metric: ' + str(self.classification_metric)) else: self.recorder.logger.info("Detection: " +str(detection_out)) @@ -148,21 +160,29 @@ def test(self): y_pred = [] self.net.eval() detection_predictions = [] + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() for i, data in enumerate(tqdm(self.test_loader, desc=f'test')): data = self.to_cuda(data) with torch.no_grad(): - output = self.net(data) - detection_output = self.net.module.get_lanes(output)['lane_output'] - detection_predictions.extend(detection_output) + with autocast(enabled=self.cfg.autocast): + output = self.net(data) + detection_output = self.net.module.get_lanes(output)['lane_output'] + detection_predictions.extend(detection_output) if self.cfg.classification: y_true.extend((data['category'].cpu().numpy()).flatten('C').tolist()) - score = F.softmax(output['category'].cuda(), dim=2) - score = score.argmax(dim=2) + score = F.softmax(output['category'].cuda(), dim=1) + score = score.argmax(dim=1) y_pred.extend((score.cpu().numpy()).flatten('C').tolist()) classification_acc += self.test_loader.dataset.evaluate_classification(output['category'].cuda(), data['category'].cuda()) + end.record() + torch.cuda.synchronize() + print('execution time in milliseconds per image: {}'. format(start.elapsed_time(end)/2782)) + detection_out = self.test_loader.dataset.evaluate_detection(detection_predictions, self.cfg.work_dir) if self.cfg.classification: diff --git a/lanedet/models/heads/lane_cls.py b/lanedet/models/heads/lane_cls.py index f605212..361f71d 100644 --- a/lanedet/models/heads/lane_cls.py +++ b/lanedet/models/heads/lane_cls.py @@ -28,10 +28,10 @@ def __init__(self, dim, cat_dim, cfg=None): if self.cfg.classification: self.category = torch.nn.Sequential( - torch.nn.Linear(1800, 512), - torch.nn.BatchNorm1d(512), + torch.nn.Linear(1800, 256), + torch.nn.BatchNorm1d(256), torch.nn.ReLU(), - torch.nn.Linear(512, 100), + torch.nn.Linear(256, 100), torch.nn.ReLU(), torch.nn.Linear(100, np.prod(self.cat_dim)) ) @@ -65,16 +65,14 @@ def loss(self, output, batch): total_loss = 0 loss_stats = {} det_loss = criterion(output['det'], batch['cls_label']) + classification_loss_weight = 0.7 if self.cfg.classification: loss_fn = torch.nn.CrossEntropyLoss() - classification_output = output['category'].reshape(self.cfg.batch_size*self.cfg.num_lanes, self.cfg.num_classes) - targets = batch['category'].reshape(self.cfg.batch_size*self.cfg.num_lanes) - - cat_loss = loss_fn(classification_output, targets) + cat_loss = loss_fn(output['category'], batch['category']) loss_stats.update({'det_loss': det_loss, 'cls_loss': cat_loss}) - total_loss = det_loss + cat_loss*0.7 + total_loss = det_loss + cat_loss*classification_loss_weight else: loss_stats.update({'det_loss': det_loss}) total_loss = det_loss @@ -116,7 +114,6 @@ def get_lanes(self, pred): def forward(self, x, **kwargs): x = x[-1] #print(x.shape) - x = self.pool(x).view(-1, 1800) # shape will be batch size x 1800 though det = self.det(x).view(-1, *self.dim) if self.cfg.classification: diff --git a/lanedet/models/heads/lane_seg.py b/lanedet/models/heads/lane_seg.py index 2aefa78..43cd31f 100644 --- a/lanedet/models/heads/lane_seg.py +++ b/lanedet/models/heads/lane_seg.py @@ -30,15 +30,15 @@ def __init__(self, decoder, exist=None, thr=0.6, padding=1, bias=False ) - self.bn1 = torch.nn.BatchNorm2d(out_channels) + #self.bn1 = torch.nn.BatchNorm2d(out_channels) self.relu = torch.nn.ReLU(inplace=True) self.category = torch.nn.Sequential( - torch.nn.Dropout(p=0.2), - torch.nn.Linear(353280, 512), - torch.nn.BatchNorm1d(512), + torch.nn.Dropout(p=0.3), + torch.nn.Linear(353280, 256), + torch.nn.BatchNorm1d(256), torch.nn.ReLU(), - torch.nn.Linear(512, 100), + torch.nn.Linear(256, 100), torch.nn.ReLU(), torch.nn.Linear(100, np.prod(self.cat_dim)) ) @@ -119,11 +119,8 @@ def loss(self, output, batch): if self.cfg.classification: loss_fn = torch.nn.CrossEntropyLoss() - classification_output = output['category'].reshape(self.cfg.batch_size*(self.cfg.num_lanes - 1), self.cfg.num_classes) - targets = batch['category'].reshape(self.cfg.batch_size*(self.cfg.num_lanes - 1)) - - cat_loss = loss_fn(classification_output, targets) - loss += cat_loss*0.7 + cat_loss = loss_fn(output['category'], batch['category']) + loss += cat_loss*0.6 loss_stats.update({'cls_loss': cat_loss}) if 'exist' in output: @@ -145,14 +142,14 @@ def forward(self, x, **kwargs): if self.cfg.classification: x= output['seg'][:,1:, ...] - print(x.shape) + #print(x.shape) x = self.maxpool(x) - print(x.shape) + #print(x.shape) x = self.conv1(x) - print(x.shape) - x = self.bn1(x) + #print(x.shape) + #x = self.bn1(x) x = self.relu(x).view(-1, 353280) - print(x.shape) + #print(x.shape) category = self.category(x).view(-1, *self.cat_dim) output.update({'category': category}) diff --git a/main.py b/main.py index a743ab9..fecffcd 100644 --- a/main.py +++ b/main.py @@ -31,6 +31,8 @@ def main(): if args.validate: runner.validate() + elif args.test: + runner.test() else: runner.train() @@ -53,6 +55,10 @@ def parse_args(): '--validate', action='store_true', help='whether to evaluate the checkpoint during training') + parser.add_argument( + '--test', + action='store_true', + help='whether to evaluate the checkpoint during training') parser.add_argument('--gpus', nargs='+', type=int, default='0') parser.add_argument('--seed', type=int, default=0, help='random seed') diff --git a/tools/detect.py b/tools/detect.py index e49039e..5c62b0c 100644 --- a/tools/detect.py +++ b/tools/detect.py @@ -13,6 +13,7 @@ from pathlib import Path from tqdm import tqdm import torch.nn.functional as F +from torch.cuda.amp import autocast, GradScaler class Detect(object): def __init__(self, cfg): @@ -36,8 +37,9 @@ def preprocess(self, img_path): def inference(self, data): with torch.no_grad(): - data = self.net(data) - lane_detection, lane_indx = self.net.module.get_lanes(data) + with autocast(enabled=self.cfg.autocast): + data = self.net(data) + lane_detection, lane_indx = self.net.module.get_lanes(data) if self.cfg.classification: lane_classes = self.get_lane_class(data, lane_indx) return lane_detection[0], lane_classes