Skip to content

Commit 34da872

Browse files
fix tf conversion in new v6 models (#5153)
* fix `tf` conversion in new v6 (#5147) * sort imports Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
1 parent 956be8e commit 34da872

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

models/tf.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import torch.nn as nn
2929
from tensorflow import keras
3030

31-
from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, autopad, C3
31+
from models.common import Bottleneck, BottleneckCSP, Concat, Conv, C3, DWConv, Focus, SPP, SPPF, autopad
3232
from models.experimental import CrossConv, MixConv2d, attempt_load
3333
from models.yolo import Detect
3434
from utils.general import make_divisible, print_args, set_logging
@@ -183,6 +183,22 @@ def call(self, inputs):
183183
return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3))
184184

185185

186+
class TFSPPF(keras.layers.Layer):
187+
# Spatial pyramid pooling-Fast layer
188+
def __init__(self, c1, c2, k=5, w=None):
189+
super(TFSPPF, self).__init__()
190+
c_ = c1 // 2 # hidden channels
191+
self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
192+
self.cv2 = TFConv(c_ * 4, c2, 1, 1, w=w.cv2)
193+
self.m = keras.layers.MaxPool2D(pool_size=k, strides=1, padding='SAME')
194+
195+
def call(self, inputs):
196+
x = self.cv1(inputs)
197+
y1 = self.m(x)
198+
y2 = self.m(y1)
199+
return self.cv2(tf.concat([x, y1, y2, self.m(y2)], 3))
200+
201+
186202
class TFDetect(keras.layers.Layer):
187203
def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None): # detection layer
188204
super(TFDetect, self).__init__()
@@ -272,7 +288,7 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
272288
pass
273289

274290
n = max(round(n * gd), 1) if n > 1 else n # depth gain
275-
if m in [nn.Conv2d, Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
291+
if m in [nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
276292
c1, c2 = ch[f], args[0]
277293
c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
278294

0 commit comments

Comments
 (0)