Skip to content

Commit 9867504

Browse files
committed
add weights
1 parent 391a108 commit 9867504

File tree

12 files changed

+94
-104
lines changed

12 files changed

+94
-104
lines changed

alignshift/models/truncated_densenet3d_alignshift.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
# Ke Yan, Imaging Biomarkers and Computer-Aided Diagnosis Laboratory,
2-
# National Institutes of Health Clinical Center, July 2019
3-
"""The truncated Densenet-121 with FPN and 3DCE"""
42
from collections import namedtuple
53

64
import torch
@@ -13,7 +11,7 @@
1311
import re
1412
from collections import OrderedDict
1513
from mmdet.models.registry import BACKBONES
16-
from convs.operators import AlignShiftConv
14+
from alignshift.operators import AlignShiftConv
1715
import torch.utils.checkpoint as cp
1816
from mmdet.models.utils import build_conv_layer, build_norm_layer
1917
# mybn = nn.BatchNorm3d
@@ -30,13 +28,13 @@ class _DenseLayer(nn.Sequential):
3028
def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, n_fold, memory_efficient=False, ref_thickness=None):
3129
super(_DenseLayer, self).__init__()
3230
self.add_module('norm1', build_norm_layer(norm_cfg, num_input_features, postfix=1)[1]),
33-
self.add_module('relu1', nn.ReLU(inplace=True)),
31+
self.add_module('relu1', nn.ReLU(inplace=False)),
3432
self.add_module('conv1', AlignShiftConv(num_input_features, bn_size *
3533
growth_rate, kernel_size=1, stride=1, alignshift=True,
3634
bias=False, ref_thickness=ref_thickness, n_fold=n_fold,)),
3735
self.add_module('norm2', build_norm_layer(norm_cfg, bn_size * growth_rate, postfix=1)[1]),
38-
self.add_module('relu2', nn.ReLU(inplace=True)),
39-
self.add_module('conv2', AlignShiftConv(bn_size * growth_rate, growth_rate,
36+
self.add_module('relu2', nn.ReLU(inplace=False)),
37+
self.add_module('conv2', AlignShiftConv(bn_size * growth_rate, growth_rate, alignshift=True,
4038
kernel_size=3, stride=1, padding=1,
4139
bias=False, ref_thickness=ref_thickness, n_fold=n_fold)),
4240
self.drop_rate = drop_rate
@@ -77,11 +75,6 @@ def forward(self, init_features, thickness):
7775
features.append(new_features)
7876
return torch.cat(features, 1)
7977

80-
class _StageNormRelu(nn.Sequential):
81-
def __init__(self, num_input_features):
82-
super().__init__()
83-
self.add_module('norm', build_norm_layer(norm_cfg, num_input_features, postfix=1)[1])
84-
self.add_module('relu', nn.ReLU(inplace=True))
8578

8679
class _Transition(nn.Sequential):
8780
def __init__(self, num_input_features, num_output_features):
@@ -101,17 +94,13 @@ def __init__(self, input_features, input_slice):
10194
# self.add_module('reduction_z_pooling', nn.AvgPool3d(kernel_size=[input_slice, 1, 1], stride=1))
10295
@BACKBONES.register_module
10396
class DenseNetCustomTrunc3dAlign(nn.Module):
104-
"""The truncated Densenet-121 with FPN and 3DCE"""
105-
# truncated since transition layer 3 since we find it works better in DeepLesion
106-
# We only keep the finest-level feature map after FPN
10797
def __init__(self,
10898
out_dim=256,
10999
n_cts=3,
110100
fpn_finest_layer=1,
111101
ref_thickness=2.0,
112102
n_fold=8,
113-
memory_efficient=True,
114-
syncbn=True):
103+
memory_efficient=True):
115104
super().__init__()
116105
self.depth = 121
117106
self.feature_upsample = True
@@ -176,8 +165,8 @@ def __init__(self,
176165
nn.init.kaiming_uniform_(layer.weight, a=1)
177166
nn.init.constant_(layer.bias, 0)
178167
self.init_weights()
179-
if syncbn:
180-
self = nn.SyncBatchNorm.convert_sync_batchnorm(self)
168+
# if syncbn:
169+
# self = nn.SyncBatchNorm.convert_sync_batchnorm(self)
181170

182171
def forward(self, x, thickness):
183172
x = self.conv0(x)
@@ -238,8 +227,8 @@ def init_weights(self, pretrained=True):
238227
state_dict1[new_key] = state_dict[key]
239228

240229
key = self.load_state_dict(state_dict1, strict=False)
241-
print(key)
242-
230+
#print(key)
231+
243232
def freeze(self):
244233
for name, param in self.named_parameters():
245234
print('freezing', name)

alignshift/models/truncated_densenet3d_tsm.py

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import re
1212
from collections import OrderedDict
1313
from mmdet.models.registry import BACKBONES
14-
from convs.operators.tsmconv import TSMConv
14+
from alignshift.operators.tsmconv import TSMConv
1515
import torch.utils.checkpoint as cp
1616
from mmdet.models.utils import build_conv_layer, build_norm_layer
1717
# mybn = nn.BatchNorm3d
@@ -29,12 +29,12 @@ class _DenseLayer(nn.Sequential):
2929
def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, n_fold, memory_efficient=False):
3030
super(_DenseLayer, self).__init__()
3131
self.add_module('norm1', build_norm_layer(norm_cfg, num_input_features, postfix=1)[1]),
32-
self.add_module('relu1', nn.ReLU(inplace=True)),
32+
self.add_module('relu1', nn.ReLU(inplace=False)),
3333
self.add_module('conv1', TSMConv(num_input_features, bn_size *
3434
growth_rate, kernel_size=1, stride=1,
3535
bias=False, n_fold=n_fold)),
3636
self.add_module('norm2', build_norm_layer(norm_cfg, bn_size * growth_rate, postfix=1)[1]),
37-
self.add_module('relu2', nn.ReLU(inplace=True)),
37+
self.add_module('relu2', nn.ReLU(inplace=False)),
3838
self.add_module('conv2', TSMConv(bn_size * growth_rate, growth_rate,
3939
kernel_size=3, stride=1, padding=1,
4040
bias=False, n_fold=n_fold)),
@@ -83,7 +83,7 @@ def __init__(self, num_input_features, num_output_features):
8383
self.add_module('relu', nn.ReLU(inplace=True))
8484
self.add_module('conv', TSMConv(num_input_features, num_output_features,
8585
kernel_size=1, stride=1, bias=False, tsm=False))
86-
self.add_module('pool', nn.AvgPool3d(kernel_size=[1, 2, 2], stride=[1, 2, 2]))
86+
self.add_module('pool', nn.AvgPool3d(kernel_size=[1, 2, 2], stride=[1, 2, 2]))#, padding=[0,1,1]
8787

8888
class _Reduction_z(nn.Sequential):
8989
def __init__(self, input_features, input_slice):
@@ -98,8 +98,7 @@ def __init__(self,
9898
n_cts=3,
9999
fpn_finest_layer=1,
100100
memory_efficient=True,
101-
n_fold=8,
102-
syncbn=True):
101+
n_fold=8,):
103102
super().__init__()
104103
self.depth = 121
105104
self.feature_upsample = True
@@ -126,16 +125,17 @@ def __init__(self,
126125
# Each denseblock
127126
num_features = num_init_features
128127
for i, num_layers in enumerate(block_config):
129-
block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,memory_efficient=memory_efficient,
130-
bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate, n_fold=n_fold)
128+
block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
129+
bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate,
130+
n_fold=self.n_fold, memory_efficient=memory_efficient)
131131
self.add_module('denseblock%d' % (i + 1), block)
132132
num_features = num_features + num_layers * growth_rate
133+
reductionz = _Reduction_z(num_features, self.n_cts)
134+
self.add_module('reductionz%d' % (i + 1), reductionz)
133135
if i != len(block_config) - 1:
134136
trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
135137
self.add_module('transition%d' % (i + 1), trans)
136138
num_features = num_features // 2
137-
reductionz = _Reduction_z(num_features, self.n_cts)
138-
self.add_module('reductionz%d' % (i + 1), reductionz)
139139

140140
# Final batch norm
141141
# self.add_module('norm5', nn.BatchNorm2d(num_features))
@@ -159,38 +159,40 @@ def __init__(self,
159159
nn.init.kaiming_uniform_(layer.weight, a=1)
160160
nn.init.constant_(layer.bias, 0)
161161
self.init_weights()
162-
if syncbn:
163-
self = nn.SyncBatchNorm.convert_sync_batchnorm(self)
162+
# if syncbn:
163+
# self = nn.SyncBatchNorm.convert_sync_batchnorm(self)
164164

165165
def forward(self, x):
166166
x = self.conv0(x)
167167
x = self.norm0(x)
168-
relu0 = self.relu0(x)
169-
pool0 = self.pool0(relu0)
168+
x = self.relu0(x)
169+
x = self.pool0(x)
170170

171-
db1 = self.denseblock1(pool0)
172-
ts1 = self.transition1(db1)
171+
x = self.denseblock1(x)
172+
redc1 = self.reductionz1(x)
173+
x = self.transition1(x)
173174

174-
db2 = self.denseblock2(ts1)
175-
ts2 = self.transition2(db2)
176175

177-
db3 = self.denseblock3(ts2)
176+
x = self.denseblock2(x)
177+
redc2 = self.reductionz2(x)
178+
x = self.transition2(x)
178179

180+
181+
x = self.denseblock3(x)
182+
redc3 = self.reductionz3(x)
179183
# truncated since here since we find it works better in DeepLesion
180184
# ts3 = self.transition3(db3)
181185
# db4 = self.denseblock4(ts3)
182186

183-
if self.feature_upsample:
184-
ftmaps = [relu0[:,:,self.mid_ct,...], db1[:,:,self.mid_ct,...], db2[:,:,self.mid_ct,...], db3[:,:,self.mid_ct,...]]
185-
x = self.lateral4(ftmaps[-1])
186-
for p in range(3, self.fpn_finest_layer - 1, -1):
187-
x = F.interpolate(x, scale_factor=2, mode="nearest")
188-
y = ftmaps[p-1]
189-
lateral = getattr(self, 'lateral%d' % p)(y)
190-
x += lateral
191-
return [x]
192-
else:
193-
return [db3]
187+
# if self.feature_upsample:
188+
ftmaps = [None, redc1.squeeze(2), redc2.squeeze(2), redc3.squeeze(2)]
189+
x = self.lateral4(ftmaps[-1])
190+
for p in range(3, self.fpn_finest_layer - 1, -1):
191+
x = F.interpolate(x, scale_factor=2, mode="nearest")
192+
y = ftmaps[p-1]
193+
lateral = getattr(self, 'lateral%d' % p)(y)
194+
x += lateral
195+
return [x]
194196

195197
def init_weights(self, pretrained=True):
196198
pattern = re.compile(

alignshift/operators/tsmconv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class TSMConv(_ConvNd):
1717
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
1818
padding=0, dilation=1, groups=1,
1919
bias=True, padding_mode='zeros',
20-
n_fold=8, tsm=True, inplace=True,
20+
n_fold=8, tsm=True, inplace=False,
2121
shift_padding_zero=True):
2222

2323
kernel_size = _pair_same(kernel_size)

deeplesion.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ AlignShift:Bridging the Gap of Imaging Thickness in 3D Anisotropic Volumes ([arX
2020
the experiment code is base on [mmdetection](https://github.com/open-mmlab/mmdetection)
2121
,this directory consists of compounents used in mmdetection.
2222
* ``mmdet``
23+
2324
## Convert a 2D model into 3D with a single line of code
2425

2526
```python
File renamed without changes.

deeplesion/dataset/transform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ def __call__(self, results):
5050
if 'gt_semantic_seg' in results:
5151
results['gt_semantic_seg'] = DC(
5252
to_tensor(results['gt_semantic_seg'][None, ...]), stack=True)
53-
if 'z_spacing' in results:
54-
results['z_spacing'] = DC(to_tensor(results['z_spacing']), stack=True, pad_dims=None)
53+
if 'thickness' in results:
54+
results['thickness'] = DC(to_tensor(results['thickness']), stack=True, pad_dims=None)
5555
return results
5656

5757
def __repr__(self):

deeplesion/mconfigs/densenet_25d.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ..ENVIRON import data_root
1+
from deeplesion.ENVIRON import data_root
22
anchor_scales = [4, 6, 8, 12, 24, 48]#, 64
33
# fp16 = dict(loss_scale=96.)
44

@@ -9,7 +9,7 @@
99
DATA_AUG_POSITION = False,
1010
NORM_SPACING = 0.,
1111
SLICE_INTV = 2.0,
12-
NUM_SLICES = 7,
12+
NUM_SLICES = 3,
1313
GROUNP_ZSAPACING = False,
1414
)
1515
input_channel = dataset_transform['NUM_SLICES']
@@ -21,6 +21,7 @@
2121
pretrained= False,
2222
backbone=dict(
2323
type='DenseNetCustomTrunc',
24+
in_channels=input_channel,
2425
out_dim=512,
2526
fpn_finest_layer=2,),
2627
rpn_head=dict(
@@ -127,7 +128,7 @@
127128
max_per_img=50,
128129
mask_thr_binary=0.5))
129130
# dataset settings
130-
dataset_type = 'DeepLesionDataset_25d'
131+
dataset_type = 'DeepLesionDataset25d'
131132

132133
img_norm_cfg = dict(
133134
mean=[0.] * input_channel, std=[1.] * input_channel, to_rgb=False)
@@ -166,7 +167,7 @@
166167
]
167168
train_pipeline = [
168169
dict(type='DefaultFormatBundle'),
169-
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'z_spacing']),#, 'flage'#, 'img_info'#, 'z_spacing'
170+
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'thickness']),#, 'flage'#, 'img_info'#, 'z_spacing'
170171
]
171172

172173
test_pipeline = [
@@ -185,14 +186,14 @@
185186
dicm2png_cfg=dataset_transform),
186187
with_mask=True,
187188
with_label=True,
188-
test=dict(
189+
val=dict(
189190
type=dataset_type,
190191
ann_file=data_root + 'val_ann.pkl',
191192
image_path=data_root + 'Images_png/',
192193
pipeline=train_pipeline,
193194
pre_pipeline = pre_pipeline_test,
194195
dicm2png_cfg=dataset_transform),
195-
val=dict(
196+
test=dict(
196197
type=dataset_type,
197198
ann_file=data_root + 'test_ann.pkl',
198199
image_path=data_root + 'Images_png/',

deeplesion/mconfigs/densenet_align.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
DATA_AUG_POSITION = False,
99
NORM_SPACING = 0.,
1010
SLICE_INTV = 2.0,
11-
NUM_SLICES = 7,
11+
NUM_SLICES = 3,
1212
GROUNP_ZSAPACING = False,
1313
)
1414
input_channel = dataset_transform['NUM_SLICES']

deeplesion/mconfigs/densenet_tsm.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
DATA_AUG_POSITION = False,
1010
NORM_SPACING = 0.,
1111
SLICE_INTV = 2.0,
12-
NUM_SLICES = 3,
12+
NUM_SLICES = 7,
1313
GROUNP_ZSAPACING = False,
1414
)
1515
input_channel = dataset_transform['NUM_SLICES']
@@ -25,12 +25,11 @@
2525
out_dim=feature_channel,
2626
fpn_finest_layer=2,
2727
n_fold=8,
28-
memory_efficient=True,
29-
syncbn=False),
28+
memory_efficient=True),
3029
rpn_head=dict(
3130
type='RPNHead',
3231
in_channels=feature_channel,
33-
feat_channels=feature_channel,###原版这俩好像没有conv
32+
feat_channels=feature_channel,
3433
anchor_scales=anchor_scales,
3534
anchor_ratios=[0.5, 1., 2.0],
3635
anchor_base_sizes=[4],
@@ -63,7 +62,7 @@
6362
mask_roi_extractor=dict(
6463
type='SingleRoIExtractor',
6564
roi_layer=dict(type='RoIAlign', out_size=14, sample_num=2),
66-
finest_scale=196, #按照面积是这个的倍数, 到该层去取feature
65+
finest_scale=196,
6766
out_channels=feature_channel,
6867
featmap_strides=[4]),
6968
mask_head=dict(
@@ -170,7 +169,7 @@
170169
]
171170
train_pipeline = [
172171
dict(type='DefaultFormatBundle_3d'),
173-
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),#, 'flage'#, 'img_info'#, 'z_spacing', , 'thickness'
172+
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'thickness']),#, 'flage'#, 'img_info'#, 'z_spacing', , 'thickness'
174173
]
175174

176175
test_pipeline = [
@@ -190,14 +189,14 @@
190189
dicm2png_cfg=dataset_transform),
191190
with_mask=True,
192191
with_label=True,
193-
test=dict(
192+
val=dict(
194193
type=dataset_type,
195194
ann_file=data_root + 'val_ann.pkl',
196195
image_path=data_root + 'Images_png/',
197196
pipeline=train_pipeline,
198197
pre_pipeline = pre_pipeline_test,
199198
dicm2png_cfg=dataset_transform),
200-
val=dict(
199+
test=dict(
201200
type=dataset_type,
202201
ann_file=data_root + 'test_ann.pkl',
203202
image_path=data_root + 'Images_png/',

deeplesion/models/truncated_densenet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@
1616
class DenseNetCustomTrunc(nn.Module):
1717
def __init__(self,
1818
out_dim=256,
19+
in_channels=3,
1920
fpn_finest_layer=1):
2021
super().__init__()
2122
self.depth = 121
2223
self.feature_upsample = True
2324
self.fpn_finest_layer = fpn_finest_layer
2425
self.out_dim = out_dim
25-
self.in_channel = 7
26+
self.in_channel = in_channels
2627
assert self.depth in [121]
2728
if self.depth == 121:
2829
num_init_features = 64

0 commit comments

Comments
 (0)