@@ -40,23 +40,39 @@ def __init__(
40
40
aa_layer = None ,
41
41
drop_block = None ,
42
42
drop_path = None ,
43
+ device = None ,
44
+ dtype = None ,
43
45
):
46
+ dd = {'device' : device , 'dtype' : dtype }
44
47
super (SelectiveKernelBasic , self ).__init__ ()
45
48
46
49
sk_kwargs = sk_kwargs or {}
47
- conv_kwargs = dict (act_layer = act_layer , norm_layer = norm_layer )
50
+ conv_kwargs = dict (act_layer = act_layer , norm_layer = norm_layer , ** dd )
48
51
assert cardinality == 1 , 'BasicBlock only supports cardinality of 1'
49
52
assert base_width == 64 , 'BasicBlock doest not support changing base width'
50
53
first_planes = planes // reduce_first
51
54
outplanes = planes * self .expansion
52
55
first_dilation = first_dilation or dilation
53
56
54
57
self .conv1 = SelectiveKernel (
55
- inplanes , first_planes , stride = stride , dilation = first_dilation ,
56
- aa_layer = aa_layer , drop_layer = drop_block , ** conv_kwargs , ** sk_kwargs )
58
+ inplanes ,
59
+ first_planes ,
60
+ stride = stride ,
61
+ dilation = first_dilation ,
62
+ aa_layer = aa_layer ,
63
+ drop_layer = drop_block ,
64
+ ** conv_kwargs ,
65
+ ** sk_kwargs ,
66
+ )
57
67
self .conv2 = ConvNormAct (
58
- first_planes , outplanes , kernel_size = 3 , dilation = dilation , apply_act = False , ** conv_kwargs )
59
- self .se = create_attn (attn_layer , outplanes )
68
+ first_planes ,
69
+ outplanes ,
70
+ kernel_size = 3 ,
71
+ dilation = dilation ,
72
+ apply_act = False ,
73
+ ** conv_kwargs ,
74
+ )
75
+ self .se = create_attn (attn_layer , outplanes , ** dd )
60
76
self .act = act_layer (inplace = True )
61
77
self .downsample = downsample
62
78
self .drop_path = drop_path
@@ -101,22 +117,33 @@ def __init__(
101
117
aa_layer = None ,
102
118
drop_block = None ,
103
119
drop_path = None ,
120
+ device = None ,
121
+ dtype = None ,
104
122
):
123
+ dd = {'device' : device , 'dtype' : dtype }
105
124
super (SelectiveKernelBottleneck , self ).__init__ ()
106
125
107
126
sk_kwargs = sk_kwargs or {}
108
- conv_kwargs = dict (act_layer = act_layer , norm_layer = norm_layer )
127
+ conv_kwargs = dict (act_layer = act_layer , norm_layer = norm_layer , ** dd )
109
128
width = int (math .floor (planes * (base_width / 64 )) * cardinality )
110
129
first_planes = width // reduce_first
111
130
outplanes = planes * self .expansion
112
131
first_dilation = first_dilation or dilation
113
132
114
133
self .conv1 = ConvNormAct (inplanes , first_planes , kernel_size = 1 , ** conv_kwargs )
115
134
self .conv2 = SelectiveKernel (
116
- first_planes , width , stride = stride , dilation = first_dilation , groups = cardinality ,
117
- aa_layer = aa_layer , drop_layer = drop_block , ** conv_kwargs , ** sk_kwargs )
135
+ first_planes ,
136
+ width ,
137
+ stride = stride ,
138
+ dilation = first_dilation ,
139
+ groups = cardinality ,
140
+ aa_layer = aa_layer ,
141
+ drop_layer = drop_block ,
142
+ ** conv_kwargs ,
143
+ ** sk_kwargs ,
144
+ )
118
145
self .conv3 = ConvNormAct (width , outplanes , kernel_size = 1 , apply_act = False , ** conv_kwargs )
119
- self .se = create_attn (attn_layer , outplanes )
146
+ self .se = create_attn (attn_layer , outplanes , ** dd )
120
147
self .act = act_layer (inplace = True )
121
148
self .downsample = downsample
122
149
self .drop_path = drop_path
0 commit comments