Skip to content
This repository has been archived by the owner on Jun 15, 2022. It is now read-only.

Commit

Permalink
Merge pull request #15 from DeNA/feature/apply_config_to_model
Browse files Browse the repository at this point in the history
Feature/apply config to model
  • Loading branch information
hirotomusiker authored Jan 7, 2019
2 parents 34d9670 + 17e59ed commit fefc12f
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 30 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ We use the following format:
MODEL:
TYPE: YOLOv3
BACKBONE: darknet53
ANCHORS: [[10, 13], [16, 30], [33, 23],
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]] # the anchors used in the YOLO layers
ANCH_MASK: [[6, 7, 8], [3, 4, 5], [0, 1, 2]] # anchor filter for each YOLO layer
N_CLASSES: 80 # number of object classes
TRAIN:
LR: 0.001
MOMENTUM: 0.9
Expand Down
5 changes: 5 additions & 0 deletions config/yolov3_default.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
MODEL:
TYPE: YOLOv3
BACKBONE: darknet53
ANCHORS: [[10, 13], [16, 30], [33, 23],
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
ANCH_MASK: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
N_CLASSES: 80
TRAIN:
LR: 0.001
MOMENTUM: 0.9
Expand Down
5 changes: 5 additions & 0 deletions config/yolov3_eval.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
MODEL:
TYPE: YOLOv3
BACKBONE: darknet53
ANCHORS: [[10, 13], [16, 30], [33, 23],
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
ANCH_MASK: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
N_CLASSES: 80
TRAIN:
LR: 0.00
MOMENTUM: 0.9
Expand Down
6 changes: 4 additions & 2 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def main():
cfg = yaml.load(f)

imgsize = cfg['TEST']['IMGSIZE']
model = YOLOv3()
model = YOLOv3(cfg['MODEL'])

confthre = cfg['TEST']['CONFTHRE']
nmsthre = cfg['TEST']['NMSTHRE']

Expand All @@ -41,7 +42,8 @@ def main():

img = cv2.imread(args.image)
img_raw = img.copy()[:, :, ::-1].transpose((2, 0, 1))
img, info_img = preprocess(img, imgsize) # info = (h, w, nh, nw, dx, dy)
img, info_img = preprocess(img, imgsize, jitter=0) # info = (h, w, nh, nw, dx, dy)
img = np.transpose(img / 255., (2, 0, 1))
img = torch.from_numpy(img).float().unsqueeze(0)

if args.gpu >= 0:
Expand Down
38 changes: 20 additions & 18 deletions models/yolo_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,37 +8,39 @@ class YOLOLayer(nn.Module):
"""
detection layer corresponding to yolo_layer.c of darknet
"""
def __init__(self, anch_mask, n_classes, stride, in_ch=1024, ignore_thre=0.7):
def __init__(self, config_model, layer_no, in_ch, ignore_thre=0.7):
"""
Args:
anch_mask (list of int): index indicating the anchors to be used in this layer.
n_classes (int): number of classes
stride (int): the corresponding pixel number of the input image for \
one pixel in the feature map at the scale.
config_model (dict) : model configuration.
ANCHORS (list of tuples) :
ANCH_MASK: (list of int list): index indicating the anchors to be
used in YOLO layers. One of the mask group is picked from the list.
N_CLASSES (int): number of classes
layer_no (int): YOLO layer number - one from (0, 1, 2).
in_ch (int): number of input channels.
ignore_thre (float): threshold of IoU above which objectness training is ignored.
"""

super(YOLOLayer, self).__init__()
self.conv = nn.Conv2d(in_channels=in_ch,
out_channels=255, kernel_size=1, stride=1, padding=0)
self.anchors = [
(10, 13), (16, 30), (33, 23),
(30, 61), (62, 45), (59, 119),
(116, 90), (156, 198), (373, 326)]
self.anch_mask = anch_mask
self.n_anchors = 3
self.n_classes = n_classes
strides = [32, 16, 8] # fixed
self.anchors = config_model['ANCHORS']
self.anch_mask = config_model['ANCH_MASK'][layer_no]
self.n_anchors = len(self.anch_mask)
self.n_classes = config_model['N_CLASSES']
self.ignore_thre = ignore_thre
self.l2_loss = nn.MSELoss(size_average=False)
self.bce_loss = nn.BCELoss(size_average=False)
self.stride = stride
self.all_anchors_grid = [(w / stride, h / stride)
self.stride = strides[layer_no]
self.all_anchors_grid = [(w / self.stride, h / self.stride)
for w, h in self.anchors]
self.masked_anchors = [self.all_anchors_grid[i]
for i in self.anch_mask]
self.ref_anchors = np.zeros((len(self.all_anchors_grid), 4))
self.ref_anchors[:, 2:] = np.array(self.all_anchors_grid)
self.ref_anchors = torch.FloatTensor(self.ref_anchors)
self.conv = nn.Conv2d(in_channels=in_ch,
out_channels=self.n_anchors * (self.n_classes + 5),
kernel_size=1, stride=1, padding=0)

def forward(self, xin, labels=None):
"""
Expand Down Expand Up @@ -178,11 +180,11 @@ class (float): class index.
# loss calculation

output[..., 4] *= obj_mask
output[..., np.r_[0:4, 5:85]] *= tgt_mask
output[..., np.r_[0:4, 5:n_ch]] *= tgt_mask
output[..., 2:4] *= tgt_scale

target[..., 4] *= obj_mask
target[..., np.r_[0:4, 5:85]] *= tgt_mask
target[..., np.r_[0:4, 5:n_ch]] *= tgt_mask
target[..., 2:4] *= tgt_scale

bceloss = nn.BCELoss(weight=tgt_scale*tgt_scale,
Expand Down
23 changes: 14 additions & 9 deletions models/yolov3.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,17 @@ def forward(self, x):
return x


def create_yolov3_modules(ignore_thre):
def create_yolov3_modules(config_model, ignore_thre):
"""
Build yolov3 layer modules.
Args:
config_model (dict): model configuration.
See YOLOLayer class for details.
ignore_thre (float): used in YOLOLayer.
Returns:
mlist (ModuleList): YOLOv3 module list.
"""

# DarkNet53
mlist = nn.ModuleList()
mlist.append(add_conv(in_ch=3, out_ch=32, ksize=3, stride=1))
Expand All @@ -82,8 +85,7 @@ def create_yolov3_modules(ignore_thre):
# 1st yolo branch
mlist.append(add_conv(in_ch=512, out_ch=1024, ksize=3, stride=1))
mlist.append(
YOLOLayer(anch_mask=[6, 7, 8], n_classes=80, stride=32, in_ch=1024,
ignore_thre=ignore_thre))
YOLOLayer(config_model, layer_no=0, in_ch=1024, ignore_thre=ignore_thre))

mlist.append(add_conv(in_ch=512, out_ch=256, ksize=1, stride=1))
mlist.append(nn.Upsample(scale_factor=2, mode='nearest'))
Expand All @@ -94,17 +96,15 @@ def create_yolov3_modules(ignore_thre):
# 2nd yolo branch
mlist.append(add_conv(in_ch=256, out_ch=512, ksize=3, stride=1))
mlist.append(
YOLOLayer(anch_mask=[3, 4, 5], n_classes=80, stride=16, in_ch=512,
ignore_thre=ignore_thre))
YOLOLayer(config_model, layer_no=1, in_ch=512, ignore_thre=ignore_thre))

mlist.append(add_conv(in_ch=256, out_ch=128, ksize=1, stride=1))
mlist.append(nn.Upsample(scale_factor=2, mode='nearest'))
mlist.append(add_conv(in_ch=384, out_ch=128, ksize=1, stride=1))
mlist.append(add_conv(in_ch=128, out_ch=256, ksize=3, stride=1))
mlist.append(resblock(ch=256, nblocks=2, shortcut=False))
mlist.append(
YOLOLayer(anch_mask=[0, 1, 2], n_classes=80, stride=8, in_ch=256,
ignore_thre=ignore_thre))
YOLOLayer(config_model, layer_no=2, in_ch=256, ignore_thre=ignore_thre))

return mlist

Expand All @@ -115,14 +115,19 @@ class YOLOv3(nn.Module):
The network returns loss values from three YOLO layers during training \
and detection results during test.
"""
def __init__(self, ignore_thre=0.7):
def __init__(self, config_model, ignore_thre=0.7):
"""
Initialization of YOLOv3 class.
Args:
config_model (dict): used in YOLOLayer.
ignore_thre (float): used in YOLOLayer.
"""
super(YOLOv3, self).__init__()
self.module_list = create_yolov3_modules(ignore_thre)

if config_model['TYPE'] == 'YOLOv3':
self.module_list = create_yolov3_modules(config_model, ignore_thre)
else:
raise Exception('Model name {} is not available'.format(config_model['TYPE']))

def forward(self, x, targets=None):
"""
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def main():
base_lr = lr

# Initiate model
model = YOLOv3(ignore_thre=ignore_thre)
model = YOLOv3(cfg['MODEL'], ignore_thre=ignore_thre)

if args.weights_path:
print("loading darknet weights....", args.weights_path)
Expand Down

0 comments on commit fefc12f

Please sign in to comment.