Skip to content

Commit df2135c

Browse files
authored
Merge pull request #196 from SamitHuang/main
DBNet output changed in inference mode for faster speed - only binary map is output
2 parents ac357ec + f517821 commit df2135c

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

mindocr/models/heads/det_db_head.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ def __init__(self, in_channels: int, k=50, adaptive=False, bias=False, weight_in
77
self.adaptive = adaptive
88

99
self.segm = self._init_heatmap(in_channels, in_channels // 4, weight_init, bias)
10-
if adaptive:
10+
if self.adaptive:
1111
self.thresh = self._init_heatmap(in_channels, in_channels // 4, weight_init, bias)
1212
self.k = k
1313
self.diff_bin = nn.Sigmoid()
@@ -29,11 +29,20 @@ def _init_heatmap(in_channels, inter_channels, weight_init, bias):
2929
])
3030

3131
def construct(self, features):
32+
'''
33+
Args:
34+
features (Tensor): features output by backbone
35+
Returns:
36+
pred (dict):
37+
- binary: predicted binary map
38+
- thresh: predicted threshold map, only used if adaptive is True in training
39+
- thres_binary: differentiable binary map, only if adaptive is True in training
40+
'''
3241
pred = {'binary': self.segm(features)}
3342

34-
if self.adaptive:
43+
if self.adaptive and self.training:
44+
# only use binary map to derive polygons in inference
3545
pred['thresh'] = self.thresh(features)
3646
pred['thresh_binary'] = self.diff_bin(
3747
self.k * (pred['binary'] - pred['thresh'])) # Differentiable Binarization
38-
3948
return pred

mindocr/postprocess/det_postprocess.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,12 @@ def __call__(self, pred):
3030
polygons: np.ndarray of shape (N, K, 4, 2) for the polygons of objective regions if region_type is 'quad'
3131
scores: np.ndarray of shape (N, K), score for each box
3232
"""
33-
pred = (pred[self._name] if isinstance(pred, dict) else pred[self._names[self._name]]).squeeze(1)
34-
pred = pred.asnumpy()
33+
if isinstance(pred, dict):
34+
pred = pred[self._name]
35+
elif isinstance(pred, tuple):
36+
pred = pred[self._names[self._name]]
37+
pred = pred.squeeze(1).asnumpy()
38+
3539
segmentation = pred >= self._binary_thresh
3640

3741
# FIXME: dest_size is supposed to be the original image shape (pred.shape -> batch['shape'])

0 commit comments

Comments
 (0)