Network head output type changed to fix CoreSegment error in graph mo… #199


Merged
merged 1 commit into from
Apr 20, 2023
49 changes: 35 additions & 14 deletions mindocr/losses/det_loss.py
@@ -1,5 +1,7 @@
from typing import Tuple, Union
from mindspore import nn, ops
import mindspore as ms
from mindspore import Tensor
import mindspore.numpy as mnp

__all__ = ['L1BalancedCELoss']
@@ -28,26 +30,45 @@ def __init__(self, eps=1e-6, bce_scale=5, l1_scale=10, bce_replace="bceloss"):
self.l1_scale = l1_scale
self.bce_scale = bce_scale

def construct(self, pred, gt, gt_mask, thresh_map, thresh_mask):
def construct(self, pred: Union[Tensor, Tuple[Tensor]], gt: Tensor, gt_mask: Tensor, thresh_map: Tensor, thresh_mask: Tensor):
"""
pred: A dict which contains predictions.
thresh: The threshold prediction
binary: The text segmentation prediction.
thresh_binary: Value produced by `step_function(binary - thresh)`.
gt: Text regions bitmap gt.
mask: Ignore mask; pixels with value 1 make no contribution to the loss.
thresh_mask: Mask indicates regions cared by thresh supervision.
thresh_map: Threshold gt.
Compute the DBNet loss.
Args:
pred (Union[Tensor, Tuple[Tensor]]): network prediction consisting of
binary: The text segmentation prediction.
thresh: The threshold prediction (optional)
thresh_binary: Value produced by `step_function(binary - thresh)`. (optional)
gt (Tensor): Text regions bitmap gt.
mask (Tensor): Ignore mask; pixels with value 1 make no contribution to the loss.
thresh_mask (Tensor): Mask indicating the regions supervised by the threshold loss.
thresh_map (Tensor): Threshold gt.
Returns:
loss value (Tensor)
"""
bce_loss_output = self.bce_loss(pred['binary'], gt, gt_mask)
if isinstance(pred, ms.Tensor):
loss = self.bce_loss(pred, gt, gt_mask)
else:
binary, thresh, thresh_binary = pred
bce_loss_output = self.bce_loss(binary, gt, gt_mask)
l1_loss = self.l1_loss(thresh, thresh_map, thresh_mask)
dice_loss = self.dice_loss(thresh_binary, gt, gt_mask)
loss = dice_loss + self.l1_scale * l1_loss + self.bce_scale * bce_loss_output

if 'thresh' in pred:
l1_loss = self.l1_loss(pred['thresh'], thresh_map, thresh_mask)
dice_loss = self.dice_loss(pred['thresh_binary'], gt, gt_mask)
'''
if isinstance(pred, tuple):
binary, thresh, thresh_binary = pred
else:
binary = pred

bce_loss_output = self.bce_loss(binary, gt, gt_mask)

if isinstance(pred, tuple):
l1_loss = self.l1_loss(thresh, thresh_map, thresh_mask)
dice_loss = self.dice_loss(thresh_binary, gt, gt_mask)
loss = dice_loss + self.l1_scale * l1_loss + self.bce_scale * bce_loss_output
else:
loss = bce_loss_output

'''
return loss
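
For context, the tuple-vs-Tensor dispatch introduced above can be exercised with a standalone sketch. The class below is a simplified stand-in, not the mindocr implementation: it uses plain `nn.BCELoss`/`nn.L1Loss` in place of the balanced BCE and dice terms, ignores the masks, and only illustrates how the two prediction forms are handled; names and shapes are made up for illustration.

```python
# Simplified stand-in for the dispatch in L1BalancedCELoss.construct; not the mindocr loss itself.
import numpy as np
from mindspore import Tensor, nn


class ToyBalancedLoss(nn.Cell):
    def __init__(self, bce_scale=5, l1_scale=10):
        super().__init__()
        self.bce = nn.BCELoss(reduction="mean")
        self.l1 = nn.L1Loss(reduction="mean")
        self.bce_scale = bce_scale
        self.l1_scale = l1_scale

    def construct(self, pred, gt, gt_mask, thresh_map, thresh_mask):
        # Masks are accepted for signature parity with the real loss but ignored in this toy.
        if isinstance(pred, Tensor):
            # Single-output head: only the segmentation term is available.
            return self.bce(pred, gt)
        # Three-output head: unpack in the order produced by DBHead.
        binary, thresh, thresh_binary = pred
        bce = self.bce(binary, gt)
        l1 = self.l1(thresh, thresh_map)
        dice = self.bce(thresh_binary, gt)  # stand-in for the dice term
        return dice + self.l1_scale * l1 + self.bce_scale * bce


n, h, w = 2, 16, 16
gt = Tensor(np.random.randint(0, 2, (n, 1, h, w)).astype(np.float32))
ones = Tensor(np.ones((n, 1, h, w), np.float32))
binary, thresh, thresh_binary, thresh_map = [
    Tensor(np.random.rand(n, 1, h, w).astype(np.float32)) for _ in range(4)
]

loss_fn = ToyBalancedLoss()
print(loss_fn((binary, thresh, thresh_binary), gt, ones, thresh_map, ones))  # tuple prediction
print(loss_fn(binary, gt, ones, thresh_map, ones))                           # single Tensor
```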


31 changes: 16 additions & 15 deletions mindocr/losses/rec_loss.py
@@ -1,9 +1,10 @@
import mindspore as ms
from mindspore import Tensor
from mindspore import nn
from mindspore.nn.loss.loss import LossBase
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore import ops
from mindspore import ops
import numpy as np

__all__ = ['CTCLoss']
@@ -14,8 +15,8 @@ class CTCLoss(LossBase):
CTCLoss definition

Args:
pred_seq_len(int): the length of the predicted character sequence. For text images, this value equals W - the width of the feature map encoded by the visual backbone. This can be obtained by probing the output shape in the network.
E.g., for a training image in shape (3, 32, 100), the feature map encoded by the resnet34 backbone is in shape (512, 1, 4), W = 4, sequence len is 4.
pred_seq_len(int): the length of the predicted character sequence. For text images, this value equals W - the width of the feature map encoded by the visual backbone. This can be obtained by probing the output shape in the network.
E.g., for a training image in shape (3, 32, 100), the feature map encoded by the resnet34 backbone is in shape (512, 1, 4), W = 4, sequence len is 4.
max_label_len(int): the maximum number of characters in a text label, i.e. max_text_len in yaml.
batch_size(int): batch size of input logits. bs
"""
@@ -31,27 +32,27 @@ def __init__(self, pred_seq_len=26, max_label_len=25, batch_size=32, reduction='
self.label_indices = Tensor(np.array(label_indices), mstype.int64)
#self.reshape = P.Reshape()
self.ctc_loss = ops.CTCLoss(ctc_merge_repeated=True)

self.reduction = reduction
print('D: ', self.label_indices.shape)

# TODO: diff from paddle, paddle takes `label_length` as input too.
def construct(self, pred, label):
def construct(self, pred: Tensor, label: Tensor):
'''
Args:
pred (dict): {head_out: logits}
logits is a Tensor in shape (W, BS, NC), where W - seq len, BS - batch size. NC - num of classes (types of character + blank + 1)
pred (Tensor): network prediction which is a
logit Tensor in shape (W, BS, NC), where W - seq len, BS - batch size. NC - num of classes (types of character + blank + 1)
label (Tensor): GT sequence of character indices in shape (BS, SL), SL - sequence length, which is padded to max_text_length
Returns:
loss value
loss value (Tensor)
'''
logit = pred['head_out']
#T, bs, nc = logit.shape
logit = pred
#T, bs, nc = logit.shape
#logit = ops.reshape(logit, (T*bs, nc))
label_values = ops.reshape(label, (-1,))

loss, _ = self.ctc_loss(logit, self.label_indices, label_values, self.sequence_length)

if self.reduction=='mean':
loss = loss.mean()

@@ -65,9 +66,9 @@ def construct(self, pred, label):
pred_seq_len = 24

loss_fn = CTCLoss(pred_seq_len, max_text_length, bs)

x = ms.Tensor(np.random.rand(pred_seq_len, bs, nc), dtype=ms.float32)
label = ms.Tensor(np.random.randint(0, nc, size=(bs, max_text_length)), dtype=ms.int32)

loss = loss_fn({'head_out': x}, label)
print(loss)
loss = loss_fn(x, label)
print(loss)
4 changes: 4 additions & 0 deletions mindocr/metrics/det_metrics.py
@@ -21,6 +21,8 @@ def _get_iou(pd, pg):


class DetectionIoUEvaluator:
'''
Matches predicted polygons against ground-truth polygons by IoU; a prediction counts as a hit when IoU >= min_iou.
'''
def __init__(self, min_iou=0.5, min_intersect=0.5):
self._min_iou = min_iou
self._min_intersect = min_intersect
@@ -77,6 +79,8 @@ def __call__(self, gt: List[dict], preds: List[np.ndarray]):


class DetMetric(nn.Metric):
'''
Text detection metric that aggregates per-sample results from DetectionIoUEvaluator.
'''
def __init__(self, device_num=1, **kwargs):
super().__init__()
self._evaluator = DetectionIoUEvaluator()
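
As a reference point for the IoU-based matching above, here is a toy IoU computation with shapely between two quadrilaterals, analogous to what the `_get_iou(pd, pg)` helper presumably computes (the polygons and library usage here are illustrative, not taken from the repo):

```python
# Toy polygon IoU with shapely; values chosen so the result can be checked by hand.
from shapely.geometry import Polygon

pd = Polygon([(0, 0), (4, 0), (4, 4), (0, 4)])   # 4x4 "predicted" box
pg = Polygon([(2, 2), (6, 2), (6, 6), (2, 6)])   # 4x4 "ground-truth" box, shifted by (2, 2)

iou = pd.intersection(pg).area / pd.union(pg).area
print(iou)  # ~0.142857 (intersection 4 / union 28), below the default min_iou of 0.5
```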
1 change: 0 additions & 1 deletion mindocr/metrics/rec_metrics.py
@@ -79,7 +79,6 @@ def update(self, *inputs):
if len(inputs) != 2:
raise ValueError('Length of inputs should be 2')
preds, gt = inputs

pred_texts = preds['texts']
#pred_confs = preds['confs']
#print('pred: ', pred_texts, len(pred_texts))
2 changes: 1 addition & 1 deletion mindocr/models/README.md
@@ -49,7 +49,7 @@ python tests/ut/test_model.py --config /path/to/yaml_config_file
* Class naming: **{HeadName}** e.g. `class DBHead`
* Class `__init__` args: MUST contain `in_channels` param as the first position, e.g. `__init__(self, in_channels, out_channels=2, **kwargs)`.
* Class `construct` args: feature (Tensor)
* Class `construct` return: prediction Union[Tensor, dict]. If it is a dict, key names should match the used key in loss function. {'maps': out, 'score': score}, which should match the loss function.
* Class `construct` return: prediction (Union[Tensor, Tuple[Tensor]]). If there is only one output, return a Tensor. If there are multiple outputs, return a Tuple of Tensors, e.g., `return output1, output2, output_N`. Note that the order must match the loss function or the postprocess function.


**Note:** if there is no neck in the model architecture, as in CRNN, you can skip writing the neck. `BaseModel` will select the last feature of the features (List[Tensor]) output by the Backbone and forward it to the Head module.
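
As an illustration of this convention, here is a hypothetical head (not part of mindocr; class and argument names are made up) that takes `in_channels` first and returns a single Tensor when it has one output and a plain tuple when it has several:

```python
import mindspore.nn as nn


class ToyHead(nn.Cell):
    # `in_channels` must be the first positional argument, per the convention above.
    def __init__(self, in_channels, out_channels=2, with_aux=False, **kwargs):
        super().__init__()
        self.main = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.aux = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.with_aux = with_aux

    def construct(self, feature):
        main_out = self.main(feature)
        if self.with_aux and self.training:
            # Multiple outputs: return a tuple whose order matches the loss/postprocess function.
            return main_out, self.aux(feature)
        # Single output: return the Tensor directly.
        return main_out


# e.g. head = ToyHead(in_channels=256, with_aux=True)
```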
27 changes: 7 additions & 20 deletions mindocr/models/base_model.py
@@ -10,11 +10,11 @@ class BaseModel(nn.Cell):
def __init__(self, config: dict):
"""
Args:
config (dict): model config
config (dict): model config
"""
super(BaseModel, self).__init__()

config = Dict(config)
config = Dict(config)
backbone_name = config.backbone.pop('name')
self.backbone = build_backbone(backbone_name, **config.backbone)

@@ -31,7 +31,7 @@ def __init__(self, config: dict):
head_name = config.head.pop('name')
self.head = build_head(head_name, in_channels=self.neck.out_channels, **config.head)

self.model_name = f'{backbone_name}_{neck_name}_{head_name}'
self.model_name = f'{backbone_name}_{neck_name}_{head_name}'

def construct(self, x):
# TODO: return bout, hout for debugging, using a dict.
@@ -41,30 +41,17 @@ def construct(self, x):

hout = self.head(nout)

# resize back for postprocess
# resize back for postprocess
#y = F.interpolate(y, size=(H, W), mode='bilinear', align_corners=True)

# for multi head, save ctc neck out for udml
'''
if isinstance(x, dict) and 'ctc_neck' in x.keys():
y["neck_out"] = x["ctc_neck"]
y["head_out"] = x
elif isinstance(x, dict):
y.update(x)
else:
y["head_out"] = x


'''

return hout


if __name__=='__main__':
model_config = {
"backbone": {
'name': 'det_resnet50',
'pretrained': False
'pretrained': False
},
"neck": {
"name": 'FPN',
@@ -75,10 +62,10 @@ def construct(self, x):
"out_channels": 2,
"k": 50
}

}
model_config.pop('neck')
model = BaseModel(model_config)
model = BaseModel(model_config)

import mindspore as ms
import time
2 changes: 1 addition & 1 deletion mindocr/models/heads/conv_head.py
@@ -9,5 +9,5 @@ def __init__(self, in_channels, out_channels,**kwargs):
)

def construct(self, x):
return {'map': self.conv(x)}
return self.conv(x)

26 changes: 15 additions & 11 deletions mindocr/models/heads/det_db_head.py
@@ -1,4 +1,6 @@
from typing import Tuple, Union
import mindspore.nn as nn
import mindspore as ms


class DBHead(nn.Cell):
@@ -28,21 +30,23 @@ def _init_heatmap(in_channels, inter_channels, weight_init, bias):
nn.Sigmoid()
])

def construct(self, features):
def construct(self, features: ms.Tensor) -> Union[ms.Tensor, Tuple[ms.Tensor]]:
'''
Args:
features (Tensor): features output by backbone
features (Tensor): encoded features
Returns:
pred (dict):
- binary: predicted binary map
- thresh: predicted threshold map, only used if adaptive is True in training
- thres_binary: differentiable binary map, only if adaptive is True in training
Union[Tensor, Tuple[Tensor]]:
binary: predicted binary map
thresh: predicted threshold map (only returned if adaptive is True in training)
thresh_binary: differentiable binary map (only returned if adaptive is True in training)
'''
pred = {'binary': self.segm(features)}
binary = self.segm(features)

if self.adaptive and self.training:
# only use binary map to derive polygons in inference
pred['thresh'] = self.thresh(features)
pred['thresh_binary'] = self.diff_bin(
self.k * (pred['binary'] - pred['thresh'])) # Differentiable Binarization
return pred
thresh = self.thresh(features)
thresh_binary = self.diff_bin(self.k * (binary - thresh)) # Differentiable Binarization
return binary, thresh, thresh_binary
else:
return binary
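
As a sanity check on the binarization step above, the snippet below evaluates `step_function(binary - thresh)` numerically. It assumes `diff_bin` is the usual sigmoid-based step function from the DB paper with amplification factor `k` (50 in the default config); the toy values show why the subtraction must stay inside the parentheses.

```python
# Toy numeric check of differentiable binarization: sigmoid(k * (binary - thresh)).
# Assumes diff_bin is a sigmoid step function; the values and k are illustrative only.
import numpy as np
import mindspore as ms
from mindspore import nn

k = 50.0
binary = ms.Tensor(np.array([[0.2, 0.6], [0.55, 0.9]], np.float32))
thresh = ms.Tensor(np.array([[0.5, 0.5], [0.5, 0.5]], np.float32))

thresh_binary = nn.Sigmoid()(k * (binary - thresh))
print(thresh_binary)  # close to 0 where binary < thresh, close to 1 where binary > thresh
```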

7 changes: 3 additions & 4 deletions mindocr/models/heads/rec_ctc_head.py
@@ -61,13 +61,13 @@ def __init__(self,
self.dense2 = nn.Dense(mid_channels, out_channels, weight_init=weight_init, bias_init=bias_init)
#self.dropout = nn.Dropout(keep_prob)

def construct(self, x):
def construct(self, x: ms.Tensor) -> ms.Tensor:
"""Feed Forward construct.
Args:
x (Tensor): feature in shape [W, BS, 2*C]
Returns:
h (Tensor): if training, h is logits in shape [W, BS, num_classes], where W - sequence len, fixed. (dim order required by ms.ctcloss)
if not training, h is probabilities in shape [BS, W, num_classes].
if not training, h is class probabilities in shape [BS, W, num_classes].
"""
h = self.dense1(x)
#x = self.dropout(x)
@@ -80,8 +80,7 @@ def construct(self, x):
h = ops.Softmax(axis=2)(h)
h = h.transpose((1, 0, 2))

pred = {'head_out': h}
return pred
return h


if __name__ == '__main__':
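
To make the training/inference shape convention above concrete, here is a shape-only sketch; no actual CTCHead weights are involved, and W, BS, NC are made-up values:

```python
# Shape-only illustration of the CTC head output convention; not the mindocr CTCHead itself.
import numpy as np
import mindspore as ms
from mindspore import ops

W, BS, NC = 26, 8, 37  # seq len, batch size, num classes (character types + blank)
logits = ms.Tensor(np.random.rand(W, BS, NC).astype(np.float32))

# Training: keep (W, BS, NC), the dim order expected by the CTC loss.
train_out = logits

# Inference: per-step class probabilities, transposed to (BS, W, NC).
eval_out = ops.Softmax(axis=2)(logits).transpose((1, 0, 2))
print(train_out.shape, eval_out.shape)  # (26, 8, 37) (8, 26, 37)
```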
22 changes: 13 additions & 9 deletions mindocr/postprocess/det_postprocess.py
@@ -1,6 +1,9 @@
from typing import Tuple, Union
import cv2
import numpy as np
from shapely.geometry import Polygon
import mindspore as ms
from mindspore import Tensor

from ..data.transforms.det_transforms import expand_poly

@@ -21,27 +24,28 @@ def __init__(self, binary_thresh=0.3, box_thresh=0.7, max_candidates=1000, expan

def __call__(self, pred):
"""
pred:
binary: text region segmentation map, with shape (N, H, W)
thresh: [if exists] threshold prediction with shape (N, H, W)
thresh_binary: [if exists] binarized with threshold, (N, H, W)
pred (Union[Tensor, Tuple[Tensor], np.ndarray]):
binary: text region segmentation map, with shape (N, 1, H, W)
thresh: threshold prediction with shape (N, 1, H, W) (optional)
thresh_binary: map binarized with the threshold, with shape (N, 1, H, W) (optional)
Returns:
result (dict) with keys:
polygons: np.ndarray of shape (N, K, 4, 2) for the polygons of objective regions if region_type is 'quad'
scores: np.ndarray of shape (N, K), score for each box
"""
if isinstance(pred, dict):
pred = pred[self._name]
elif isinstance(pred, tuple):
if isinstance(pred, tuple):
pred = pred[self._names[self._name]]
pred = pred.squeeze(1).asnumpy()

if isinstance(pred, Tensor):
pred = pred.asnumpy()
pred = pred.squeeze(1)

segmentation = pred >= self._binary_thresh

# FIXME: dest_size is supposed to be the original image shape (pred.shape -> batch['shape'])
dest_size = np.array(pred.shape[:0:-1]) # w, h order
scale = dest_size / np.array(pred.shape[:0:-1])

# TODO:
# FIXME: output as dict, keep consistent return format to recognition
return [self._extract_preds(pr, segm, scale, dest_size) for pr, segm in zip(pred, segmentation)]
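
The prediction-normalization logic above (tuple → pick the binary map, Tensor → numpy, then squeeze the channel axis) can be sketched as a standalone helper. `to_binary_map` and the index-0 choice for the binary map are illustrative assumptions mirroring the diff, not mindocr API:

```python
# Hypothetical helper mirroring the tuple/Tensor/ndarray handling above; not part of mindocr.
from typing import Tuple, Union

import numpy as np
import mindspore as ms
from mindspore import Tensor


def to_binary_map(pred: Union[Tensor, Tuple[Tensor, ...], np.ndarray]) -> np.ndarray:
    if isinstance(pred, tuple):
        pred = pred[0]       # assume the 'binary' map is the first element of the tuple
    if isinstance(pred, Tensor):
        pred = pred.asnumpy()
    return pred.squeeze(1)   # (N, 1, H, W) -> (N, H, W)


binary = ms.Tensor(np.random.rand(2, 1, 64, 64).astype(np.float32))
print(to_binary_map((binary, binary, binary)).shape)  # (2, 64, 64)
print(to_binary_map(binary).shape)                    # (2, 64, 64)
```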
