Skip to content

Commit

Permalink
add collect_key for condlane training
Browse files Browse the repository at this point in the history
  • Loading branch information
Turoad committed Mar 30, 2022
1 parent d36ba73 commit 95c06c2
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 5 deletions.
2 changes: 1 addition & 1 deletion configs/condlane/resnet101_culane.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@
),
#dict(type='Resize', size=(img_width, img_height)),
dict(type='Normalize', img_norm=img_norm),
dict(type='ToTensor', keys=['img', 'gt_hm']),
dict(type='ToTensor', keys=['img', 'gt_hm'], collect_keys=['img_metas']),
]


Expand Down
2 changes: 1 addition & 1 deletion configs/condlane/resnet50_culane.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@
),
#dict(type='Resize', size=(img_width, img_height)),
dict(type='Normalize', img_norm=img_norm),
dict(type='ToTensor', keys=['img', 'gt_hm']),
dict(type='ToTensor', keys=['img', 'gt_hm'], collect_keys=['img_metas']),
]


Expand Down
11 changes: 8 additions & 3 deletions lanedet/datasets/process/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,22 @@ class ToTensor(object):
Args:
keys (Sequence[str]): Keys that need to be converted to Tensor.
collect_keys (Sequence[str]): Keys that need to keep, but not to Tensor.
"""

def __init__(self, keys=['img', 'mask'], cfg=None):
def __init__(self, keys=['img', 'mask'], collect_keys=[], cfg=None):
self.keys = keys
self.collect_keys = collect_keys

def __call__(self, sample):
data = {}
if len(sample['img'].shape) < 3:
sample['img'] = np.expand_dims(sample['img'], -1)
for key in self.keys:
data[key] = to_tensor(sample[key])
for key in sample.keys():
if key in self.keys:
data[key] = to_tensor(sample[key])
if key in self.collect_keys:
data[key] = sample[key]
data['img'] = data['img'].permute(2, 0, 1)
return data

Expand Down
1 change: 1 addition & 0 deletions lanedet/models/heads/condlane.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,7 @@ def __init__(self,

def loss(self, output, batch):
img_metas = batch['img_metas']
batch.pop('meta')
return self.loss_impl(output, img_metas, **batch)


Expand Down

0 comments on commit 95c06c2

Please sign in to comment.