Skip to content

Commit

Permalink
updated all codes with vis
Browse files Browse the repository at this point in the history
  • Loading branch information
zillur-av committed Apr 19, 2023
1 parent e6fe576 commit 25e97aa
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 68 deletions.
6 changes: 6 additions & 0 deletions configs/resa/resa18_tusimple.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@
dict(type='ToTensor', keys=['img']),
]

infer_process = [
dict(type='Resize', size=(img_width, img_height)),
dict(type='Normalize', img_norm=img_norm),
dict(type='ToTensor', keys=['img']),
]

dataset_path = './data/tusimple'
dataset = dict(
train=dict(
Expand Down
21 changes: 13 additions & 8 deletions configs/ufld/resnet18_tusimple.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,15 @@
featuremap_out_channel = 512

griding_num = 100
num_classes = 6
num_lanes = 6


classification = True
num_classes = 8

heads = dict(type='LaneCls',
dim = (griding_num + 1, 56, num_classes))
dim = (griding_num + 1, 56, num_lanes),
cat_dim =(num_lanes, num_classes))

trainer = dict(
type='LaneCls'
Expand All @@ -31,18 +36,18 @@
weight_decay = 1e-4,
momentum = 0.9
)
#optimizer = dict(type='Adam', lr= 0.025, weight_decay = 0.0001) # 3e-4 for batchsize 8

epochs = 30
batch_size = 4
total_iter = (3216 // batch_size + 1) * epochs

import math

scheduler = dict(
type = 'LambdaLR',
lr_lambda = lambda _iter : math.pow(1 - _iter/total_iter, 0.9)
)


img_norm = dict(
mean=[103.939, 116.779, 123.68],
std=[1., 1., 1.]
Expand All @@ -64,24 +69,24 @@
dict(type='RandomUDoffsetLABEL', max_offset=100),
dict(type='RandomLROffsetLABEL', max_offset=200),
dict(type='GenerateLaneCls', row_anchor=row_anchor,
num_cols=griding_num, num_classes=num_classes),
num_cols=griding_num, num_lanes=num_lanes),
dict(type='Resize', size=(img_w, img_h)),
dict(type='Normalize', img_norm=img_norm),
dict(type='ToTensor', keys=['img', 'cls_label']),
]

val_process = [
dict(type='GenerateLaneCls', row_anchor=row_anchor,
num_cols=griding_num, num_classes=num_classes),
num_cols=griding_num, num_lanes=num_lanes),
dict(type='Resize', size=(img_w, img_h)),
dict(type='Normalize', img_norm=img_norm),
dict(type='ToTensor', keys=['img', 'cls_label']),
]

test_process = [
infer_process = [
dict(type='Resize', size=(img_w, img_h)),
dict(type='Normalize', img_norm=img_norm),
dict(type='ToTensor', keys=['img', 'cls_label']),
dict(type='ToTensor', keys=['img']),
]

dataset = dict(
Expand Down
4 changes: 2 additions & 2 deletions lanedet/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from lanedet.utils.visualization import imshow_lanes
from mmcv.parallel import DataContainer as DC
import inspect
from torch.nn.functional import one_hot

@DATASETS.register_module
class BaseDataset(Dataset):
Expand Down Expand Up @@ -40,7 +41,6 @@ def __len__(self):
def __getitem__(self, idx):
'Generates one sample of data'
data_info = self.data_infos[idx]
#print(data_info)
if not osp.isfile(data_info['img_path']):
raise FileNotFoundError('cannot find file: {}'.format(data_info['img_path']))
img = cv2.imread(data_info['img_path'])
Expand All @@ -66,6 +66,6 @@ def __getitem__(self, idx):

category = data_info['categories']
category = [0 if np.all(sample['cls_label'][:,i].numpy() == 100) else category[i] for i in range(6)]
sample['category'] = torch.tensor(category)
sample['category'] = torch.LongTensor(category)

return sample
10 changes: 5 additions & 5 deletions lanedet/datasets/process/generate_lane_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ def _grid_pts(pts, num_cols, w):

@PROCESS.register_module
class GenerateLaneCls(object):
def __init__(self, row_anchor, num_cols, num_classes, cfg):
def __init__(self, row_anchor, num_cols, num_lanes, cfg):
self.row_anchor = eval(row_anchor)
self.num_cols = num_cols #100
self.num_classes = num_classes #6
self.num_lanes = num_lanes #6

def __call__(self, sample):
label = sample['mask'] # seg_mask
Expand All @@ -57,12 +57,12 @@ def __call__(self, sample):
scale_f = lambda x : int((x * 1.0/288) * h)
sample_tmp = list(map(scale_f, self.row_anchor)) # list [160, ..... 710]

all_idx = np.zeros((self.num_classes, len(sample_tmp),2)) # 6x56x2
all_idx = np.zeros((self.num_lanes, len(sample_tmp),2)) # 6x56x2

for i,r in enumerate(sample_tmp):
label_r = np.asarray(label)[int(round(r))] # 1280 pixels in each row anchor: shape = 1280x1
# pixels are actually lane numbers like 1 to 6
for lane_idx in range(1, self.num_classes+1): # 1 to 6
for lane_idx in range(1, self.num_lanes+1): # 1 to 6
pos = np.where(label_r == lane_idx)[0] # x pixels of the lane location
if len(pos) == 0:
all_idx[lane_idx - 1, i, 0] = r # if no lane, just put y values like 160, 170
Expand All @@ -73,7 +73,7 @@ def __call__(self, sample):
all_idx[lane_idx - 1, i, 1] = pos # in x values, put mean of x pixels of the lane

all_idx_cp = all_idx.copy()
for i in range(self.num_classes):
for i in range(self.num_lanes):
if np.all(all_idx_cp[i,:,1] == -1): # if all x values are -1, ignore that lane
continue

Expand Down
4 changes: 2 additions & 2 deletions lanedet/datasets/tusimple.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,6 @@ def accuracy_fn(self, y_true, y_pred):
return acc

def evaluate_classification(self, predictions, ground_truth):
score = F.softmax(predictions, dim=1)
y_pred = score.argmax(dim=1)
score = F.softmax(predictions, dim=2)
y_pred = score.argmax(dim=2)
return self.accuracy_fn(ground_truth, y_pred)
30 changes: 16 additions & 14 deletions lanedet/engine/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self, cfg):
# self.net, device_ids = range(self.cfg.gpus)).cuda()
self.net = MMDataParallel(
self.net, device_ids = range(self.cfg.gpus)).cuda()
#self.recorder.logger.info('Network: \n' + str(self.net))
self.recorder.logger.info('Network: \n' + str(self.net))
self.resume()
self.optimizer = build_optimizer(self.cfg, self.net)
self.scheduler = build_scheduler(self.cfg, self.optimizer)
Expand Down Expand Up @@ -125,29 +125,31 @@ def validate(self):
data = self.to_cuda(data)
with torch.no_grad():
output = self.net(data)
#print(output.keys())
detection_output = self.net.module.get_lanes(output)
detection_output, _ = self.net.module.get_lanes(output)
detection_predictions.extend(detection_output)
classification_acc += self.val_loader.dataset.evaluate_classification(output['category'].cuda(), data['category'].cuda())
if self.cfg.classification:
classification_acc += self.val_loader.dataset.evaluate_classification(output['category'].cuda(), data['category'].cuda())
if self.cfg.view:
self.val_loader.dataset.view(detection_output, data['meta'])

classification_acc /= len(self.val_loader)

detection_out = self.val_loader.dataset.evaluate_detection(detection_predictions, self.cfg.work_dir)
self.recorder.logger.info("Detection: " +str(detection_out) + " "+ "classification accuracy: " + str(classification_acc))
detection_metric = detection_out
if detection_metric > self.detection_metric:
self.detection_metric = detection_metric
self.save_ckpt(is_best=True)

classification_metric = classification_acc
if classification_metric > self.classification_metric:
self.classification_metric = classification_metric
#self.save_ckpt(is_best=True)

self.recorder.logger.info('Best detection metric: ' + str(self.detection_metric) + " " + 'Best classification metric: ' + str(self.classification_metric))

if self.cfg.classification:
classification_acc /= len(self.val_loader)
self.recorder.logger.info("Detection: " +str(detection_out) + " "+ "classification accuracy: " + str(classification_acc))
classification_metric = classification_acc
if classification_metric > self.classification_metric:
self.classification_metric = classification_metric
#self.save_ckpt(is_best=True)
self.recorder.logger.info('Best detection metric: ' + str(self.detection_metric) + " " + 'Best classification metric: ' + str(self.classification_metric))
else:
self.recorder.logger.info("Detection: " +str(detection_out))
self.recorder.logger.info('Best detection metric: ' + str(self.detection_metric))

def save_ckpt(self, is_best=False):
save_model(self.net, self.optimizer, self.scheduler,
self.recorder, is_best)
62 changes: 37 additions & 25 deletions lanedet/models/heads/lane_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,31 @@

@HEADS.register_module
class LaneCls(nn.Module):
def __init__(self, dim, cfg=None):
def __init__(self, dim, cat_dim, cfg=None):
super(LaneCls, self).__init__()
self.cfg = cfg
chan = cfg.featuremap_out_channel
self.pool = torch.nn.Conv2d(chan, 8, 1)
self.cat_dim = (8, 6)
self.cat_dim = cat_dim
self.dim = dim
self.total_dim = np.prod(dim)

self.cls = torch.nn.Sequential(
self.det = torch.nn.Sequential(
torch.nn.Linear(1800, 2048),
torch.nn.ReLU(),
torch.nn.Linear(2048, self.total_dim),
)

self.category = torch.nn.Sequential(
torch.nn.Linear(1800, 512),
torch.nn.ReLU(),
torch.nn.Linear(512, 48)
)


if self.cfg.classification:
self.category = torch.nn.Sequential(
torch.nn.Linear(1800, 512),
torch.nn.BatchNorm1d(512),
torch.nn.ReLU(),
torch.nn.Linear(512, 100),
torch.nn.ReLU(),
torch.nn.Linear(100, np.prod(self.cat_dim))
)

def postprocess(self, out, localization_type='rel', flip_updown=True):
predictions = []
griding_num = self.cfg.griding_num
Expand Down Expand Up @@ -61,31 +64,38 @@ def loss(self, output, batch):
criterion = SoftmaxFocalLoss(2)
total_loss = 0
loss_stats = {}
det_loss = criterion(output['cls'], batch['cls_label'])

loss_fn = torch.nn.CrossEntropyLoss()
#print(output['category'].shape, batch['category'].shape)
#print(output['cls'].shape, batch['cls_label'].shape)
score = F.softmax(output['category'], dim=1)
cat_loss = loss_fn(score, batch['category'])
det_loss = criterion(output['det'], batch['cls_label'])

loss_stats.update({'det_loss': det_loss, 'cat_loss': cat_loss})
total_loss = det_loss + cat_loss
if self.cfg.classification:
loss_fn = torch.nn.CrossEntropyLoss()
classification_output = output['category'].reshape(self.cfg.batch_size*self.cfg.num_lanes, self.cfg.num_classes)
score = F.softmax(classification_output, dim=1)
targets = batch['category'].reshape(self.cfg.batch_size*self.cfg.num_lanes)

cat_loss = loss_fn(score, targets)

loss_stats.update({'det_loss': det_loss, 'cls_loss': cat_loss})
total_loss = det_loss + cat_loss
else:
loss_stats.update({'det_loss': det_loss})
total_loss = det_loss

ret = {'loss': total_loss , 'loss_stats': loss_stats}

return ret

def get_lanes(self, pred):
predictions = self.postprocess(pred['cls'])
predictions = self.postprocess(pred['det'])
ret = []
griding_num = self.cfg.griding_num
sample_y = list(self.cfg.sample_y)
for out in predictions:
lane_indx = []
lanes = []
for i in range(out.shape[1]):
if sum(out[:, i] != 0) <= 2: continue
out_i = out[:, i]
lane_indx.append(i)
coord = []
for k in range(out.shape[0]):
if out[k, i] <= 0: continue
Expand All @@ -98,13 +108,15 @@ def get_lanes(self, pred):
coord[:, 1] /= self.cfg.ori_img_h
lanes.append(Lane(coord))
ret.append(lanes)
return ret
return ret, lane_indx

def forward(self, x, **kwargs):
x = x[-1]
x = self.pool(x).view(-1, 1800)
cls = self.cls(x).view(-1, *self.dim)
category = self.category(x).view(-1, *self.cat_dim)
output = {'cls': cls, 'category': category}

det = self.det(x).view(-1, *self.dim)
if self.cfg.classification:
category = self.category(x).view(-1, *self.cat_dim)
output = {'det': det, 'category': category}
else:
output = {'det': det}
return output
37 changes: 34 additions & 3 deletions lanedet/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,44 @@
import os
import os.path as osp

def imshow_lanes(img, lanes, show=False, out_file=None):
for lane in lanes:
# Color palette for lane visualization
def getcolor(code):
    """Return the drawing colour assigned to a lane class code.

    Codes 1 through 7 each map to a fixed 3-tuple of 0-255 channel values.
    Any other code falls through and yields ``None``.
    """
    palette = (
        (1, (0, 255, 0)),
        (2, (0, 255, 255)),
        (3, (255, 255, 0)),
        (4, (255, 0, 0)),
        (5, (0, 0, 255)),
        (6, (45, 88, 200)),
        (7, (213, 22, 224)),
    )
    # Match with ``==`` (not a hash lookup) so every value that compares
    # equal to a known code is accepted, whatever its concrete type.
    for known_code, colour in palette:
        if code == known_code:
            return colour
    return None


def imshow_lanes(img, lanes, show=False, out_file=None, lane_classes = None):
for i, lane in enumerate(lanes):
for x, y in lane:
if x <= 0 or y <= 0:
continue
x, y = int(x), int(y)
cv2.circle(img, (x, y), 4, (255, 0, 0), 2)
if lane_classes is not None:
color = getcolor(lane_classes[i])
else:
color = (255, 0, 0)
cv2.circle(img, (x, y), 4, color, 2)

if lane_classes is not None:
cv2.putText(img,'solid-yellow',(0,40), cv2.FONT_HERSHEY_SIMPLEX, 1,getcolor(1),2,cv2.LINE_AA)
cv2.putText(img,'solid-white',(0,70), cv2.FONT_HERSHEY_SIMPLEX, 1,getcolor(2),2,cv2.LINE_AA)
cv2.putText(img,'dashed',(0,100), cv2.FONT_HERSHEY_SIMPLEX, 1,getcolor(3),2,cv2.LINE_AA)
cv2.putText(img,'double-dashed',(0,130), cv2.FONT_HERSHEY_SIMPLEX, 1,getcolor(4),2,cv2.LINE_AA)
cv2.putText(img,'Botts\'-dots',(0,170), cv2.FONT_HERSHEY_SIMPLEX, 1,getcolor(5),2,cv2.LINE_AA)
cv2.putText(img,'double-solid-yellow',(0,200), cv2.FONT_HERSHEY_SIMPLEX, 1,getcolor(6),2,cv2.LINE_AA)
cv2.putText(img,'unknown',(0,230), cv2.FONT_HERSHEY_SIMPLEX, 1,getcolor(7),2,cv2.LINE_AA)

if show:
cv2.imshow('view', img)
Expand Down
Loading

0 comments on commit 25e97aa

Please sign in to comment.