Skip to content

Add API docstring for psenet #434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions mindocr/data/transforms/det_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,12 +575,35 @@ def expand_poly(poly, distance: float, joint_type=pyclipper.JT_ROUND) -> List[li


class PSEGtDecode:
"""
PSENet transformation which shrinks text polygons.

Args:
kernel_num (int): The number of kernels.
min_shrink_ratio (float): The minimum shrink ratio.
min_shortest_edge (int): The minimum shortest edge.

Returns:
dict: A dictionary containing shrunk image data, polygons, ground truth kernels, ground truth text and masks.
"""

def __init__(self, kernel_num=7, min_shrink_ratio=0.4, min_shortest_edge=640, **kwargs):
self.kernel_num = kernel_num
self.min_shrink_ratio = min_shrink_ratio
self.min_shortest_edge = min_shortest_edge

def _shrink(self, text_polys, rate, max_shr=20):
"""
Shrink text polygons.

Args:
text_polys (list): A list of text polygons.
rate (float): The shrink rate.
max_shr (int): The maximum shrink.

Returns:
list: A list of shrunk text polygons.
"""
rate = rate * rate
shrinked_text_polys = []
for bbox in text_polys:
Expand All @@ -604,6 +627,14 @@ def _shrink(self, text_polys, rate, max_shr=20):
return shrinked_text_polys

def __call__(self, data):
"""
Args:
data (dict): A dictionary containing image data.

Returns:
dict: A dictionary containing shrunk image data, polygons,
ground truth kernels, ground truth text and masks.
"""
image = data["image"]
text_polys = data["polys"]
ignore_tags = data["ignore_tags"]
Expand Down
51 changes: 38 additions & 13 deletions mindocr/losses/det_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,19 @@ def construct(self, pred: Tensor, gt: Tensor, mask: Tensor) -> Tensor:


class PSEDiceLoss(nn.Cell):
"""
PSE Dice Loss module for text detection.

This module calculates the Dice loss between the predicted binary segmentation map and the ground truth map.

Args:
alpha (float): The weight for text loss. Default is 0.7.
ohem_ratio (int): The ratio for hard negative example mining. Default is 3.

Returns:
Tensor: The computed loss value.
"""

def __init__(self, alpha=0.7, ohem_ratio=3):
super().__init__()
self.threshold0 = Tensor(0.5, mstype.float32)
Expand Down Expand Up @@ -207,11 +220,15 @@ def __init__(self, alpha=0.7, ohem_ratio=3):

def ohem_batch(self, scores, gt_texts, training_masks):
"""
Perform online hard example mining (OHEM) for a batch of scores, ground truth texts, and training masks.

:param scores: [N * H * W]
:param gt_texts: [N * H * W]
:param training_masks: [N * H * W]
:return: [N * H * W]
Args:
scores (Tensor): The predicted scores of shape [N * H * W].
gt_texts (Tensor): The ground truth texts of shape [N * H * W].
training_masks (Tensor): The training masks of shape [N * H * W].

Returns:
Tensor: The selected masks of shape [N * H * W].
"""
batch_size = scores.shape[0]
h, w = scores.shape[1:]
Expand Down Expand Up @@ -264,11 +281,15 @@ def ohem_single(self, score, gt_text, training_mask):

def dice_loss(self, input_params, target, mask):
"""
Compute the dice loss between input parameters, target, and mask.

Args:
input_params (Tensor): The input parameters of shape [N, H, W].
target (Tensor): The target of shape [N, H, W].
mask (Tensor): The mask of shape [N, H, W].

:param input: [N, H, W]
:param target: [N, H, W]
:param mask: [N, H, W]
:return:
Returns:
Tensor: The dice loss value.
"""
batch_size = input_params.shape[0]
input_sigmoid = self.sigmoid(input_params)
Expand Down Expand Up @@ -296,12 +317,16 @@ def avg_losses(self, loss_list):

def construct(self, model_predict, gt_texts, gt_kernels, training_masks):
"""
Construct the PSE Dice Loss calculation.

Args:
model_predict (Tensor): The predicted model outputs of shape [N * 7 * H * W].
gt_texts (Tensor): The ground truth texts of shape [N * H * W].
gt_kernels (Tensor): The ground truth kernels of shape [N * 6 * H * W].
training_masks (Tensor): The training masks of shape [N * H * W].

:param model_predict: [N * 7 * H * W]
:param gt_texts: [N * H * W]
:param gt_kernels:[N * 6 * H * W]
:param training_masks:[N * H * W]
:return:
Returns:
Tensor: The computed loss value.
"""
batch_size = model_predict.shape[0]
model_predict = self.upsample(model_predict, scale_factor=4)
Expand Down
17 changes: 15 additions & 2 deletions mindocr/models/heads/det_pse_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,22 @@ def _bn(channels, momentum=0.1):


class PSEHead(nn.Cell):
"""
PSE Head module for text detection.

This module takes a single input feature map and applies convolutional operations
to generate the output feature map.

Args:
in_channels (int): The feature dimension of a single feature map generated by the neck (FPN).
hidden_size (int): The hidden size for intermediate convolutions.
out_channels (int): The output channel size.

Returns:
Tensor: The output feature map of shape [batch_size, out_channels, H, W].

"""
def __init__(self, in_channels: int, hidden_size: int, out_channels: int):
# in_channels is the feature dimension of a single feature map generated by the neck (fpn).
# BUT the fpn generates 4 feature maps, so the total number of channels is 4 * in_channels.
super().__init__()
self.in_channels = in_channels
self.hidden_size = hidden_size
Expand Down
15 changes: 15 additions & 0 deletions mindocr/models/necks/fpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,21 @@ def _bn(channels, momentum=0.1):


class PSEFPN(nn.Cell):
"""
PSE Feature Pyramid Network (FPN) module for text detection.

This module takes multiple input feature maps and performs feature fusion
and upsampling to generate a single output feature map.

Args:
in_channels (List[int]): The input channel dimensions for each feature map
in the following order: [c2, c3, c4, c5].
out_channels (int): The output channel size.

Returns:
Tensor: The output feature map of shape [batch_size, out_channels * 4, H, W].

"""
def __init__(self, in_channels: List[int], out_channels):
super().__init__()
super(PSEFPN, self).__init__()
Expand Down
18 changes: 18 additions & 0 deletions mindocr/postprocess/det_pse_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,24 @@


class PSEPostprocess(DetBasePostprocess):
"""
Post-processing module for PSENet text detection.

This module takes the network predictions and performs post-processing to obtain the final text detection results.

Args:
binary_thresh (float): The threshold value for binarization. Default is 0.5.
box_thresh (float): The threshold value for generating bounding boxes. Default is 0.85.
min_area (int): The minimum area threshold for filtering small text regions. Default is 16.
box_type (str): The type of bounding boxes to generate. Can be "quad" or "poly". Default is "quad".
scale (int): The scale factor for resizing the predicted output. Default is 4.
output_score_kernels (bool): Whether to output the scores and kernels. Default is False.
rescale_fields (list): The list of fields to be rescaled. Default is ["polys"].

Returns:
dict: A dictionary containing the final text detection results.
"""

def __init__(
self,
binary_thresh=0.5,
Expand Down