Skip to content

[WIP] pr592 适配合入 && det后处理 长短框合并 #682

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
09e9aee
Fix problems of psenet-ctw1500 training
horcham Jan 31, 2024
828fbb7
Fix problems of psenet-ctw1500 training
horcham Jan 31, 2024
cbaea06
Fix problems of psenet-ctw1500 training
horcham Jan 31, 2024
e6b64e2
pull upstream
horcham Feb 5, 2024
e074f5f
Merge branch 'main' of https://github.com/mindspore-lab/mindocr
horcham Mar 6, 2024
e38bacb
for export
horcham Mar 18, 2024
77d8138
Add the function of concatenating to crops after detection.
Bourn3z Mar 18, 2024
8c1938b
fix large npu memory cost
horcham Mar 18, 2024
d5e15db
Merge pull request #2 from Bourn3z/dev-offlineinfer
Bourn3z Mar 18, 2024
b20f0fa
Add the function of concatenating to crops after detection.
Bourn3z Mar 18, 2024
7d10934
Merge pull request #3 from Bourn3z/dev-offlineinfer
Bourn3z Mar 18, 2024
56dab20
Add the function of concatenating to crops after detection.
Bourn3z Mar 18, 2024
9ce69b8
Merge branch 'infer' into dev-offlineinfer
Bourn3z Mar 18, 2024
66850a4
Merge pull request #4 from Bourn3z/dev-offlineinfer
Bourn3z Mar 18, 2024
855bade
Add the function of concatenating to crops after detection.
Bourn3z Mar 18, 2024
4626519
Merge branch 'infer' into dev-offlineinfer
Bourn3z Mar 18, 2024
66c8509
Merge pull request #7 from Bourn3z/dev-offlineinfer
Bourn3z Mar 18, 2024
2750581
Add the function of concatenating to crops after detection.
Bourn3z Mar 18, 2024
04976fd
Merge branch 'infer' into dev-offlineinfer
Bourn3z Mar 18, 2024
f8fc62f
Merge pull request #8 from Bourn3z/dev-offlineinfer
Bourn3z Mar 18, 2024
d654c13
Add the function of concatenating to crops after detection.
Bourn3z Mar 18, 2024
f5b05fa
Merge pull request #9 from Bourn3z/dev-offlineinfer
Bourn3z Mar 18, 2024
258da17
Add the function of concatenating to crops after detection.
Bourn3z Mar 18, 2024
2621f14
Merge branch 'infer' into dev-offlineinfer
Bourn3z Mar 18, 2024
83dfa9e
Merge pull request #10 from Bourn3z/dev-offlineinfer
Bourn3z Mar 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions deploy/py_infer/src/core/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ def warmup(self):
height, width = hw_list[0]
warmup_shape = [(*other_shape, height, width)] # Only single input

dummy_tensor = [np.random.randn(*shape).astype(dtype) for shape, dtype in zip(warmup_shape, self.input_dtype)]
self.model.infer(dummy_tensor)
# dummy_tensor = [np.random.randn(*shape).astype(dtype) for shape, dtype in zip(warmup_shape, self.input_dtype)]
# self.model.infer(dummy_tensor)

def __del__(self):
if hasattr(self, "model") and self.model:
Expand Down
1 change: 1 addition & 0 deletions deploy/py_infer/src/data_process/postprocess/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def get_device_status():
def _get_status():
nonlocal status
try:
ms.set_context(max_device_memory="0.01GB")
status = ms.Tensor([0])[0:].asnumpy()[0]
except RuntimeError:
status = 1
Expand Down
3 changes: 3 additions & 0 deletions deploy/py_infer/src/infer_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ def get_args():
"--show_log", type=str2bool, default=False, required=False, help="Whether show log when inferring."
)
parser.add_argument("--save_log_dir", type=str, required=False, help="Log saving dir.")
parser.add_argument(
"--is_concat", type=str2bool, default=False, help="Whether to concatenate crops after the detection."
)

args = parser.parse_args()
setup_logger(args)
Expand Down
28 changes: 28 additions & 0 deletions deploy/py_infer/src/parallel/module/detection/det_post_node.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import cv2
import numpy as np

from ....data_process.utils import cv_utils
Expand All @@ -10,19 +11,44 @@ def __init__(self, args, msg_queue):
super(DetPostNode, self).__init__(args, msg_queue)
self.text_detector = None
self.task_type = self.args.task_type
self.is_concat = self.args.is_concat

def init_self_args(self):
self.text_detector = TextDetector(self.args)
self.text_detector.init(preprocess=False, model=False, postprocess=True)
super().init_self_args()

def concat_crops(self, crops: list):
"""
Concatenates the list of cropped images horizontally after resizing them to have the same height.

Args:
crops (list): A list of cropped images represented as numpy arrays.

Returns:
numpy.ndarray: A horizontally concatenated image array.
"""
max_height = max(crop.shape[0] for crop in crops)
resized_crops = []
for crop in crops:
h, w, c = crop.shape
new_h = max_height
new_w = int((w / h) * new_h)

resized_img = cv2.resize(crop, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
resized_crops.append(resized_img)
crops_concated = np.concatenate(resized_crops, axis=1)
return crops_concated

def process(self, input_data):
if input_data.skip:
self.send_to_next_module(input_data)
return

data = input_data.data
boxes = self.text_detector.postprocess(data["pred"], data["shape_list"])
if self.is_concat:
boxes = sorted(boxes, key=lambda points: (points[0][1], points[0][0]))

infer_res_list = []
for box in boxes:
Expand All @@ -39,6 +65,8 @@ def process(self, input_data):
for box in infer_res_list:
sub_image = cv_utils.crop_box_from_image(image, np.array(box))
sub_image_list.append(sub_image)
if self.is_concat:
sub_image_list = len(sub_image_list) * [self.concat_crops(sub_image_list)]
input_data.sub_image_list = sub_image_list

input_data.data = None
Expand Down
11 changes: 10 additions & 1 deletion mindocr/losses/det_loss.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
from math import pi
from typing import Tuple, Union

Expand All @@ -10,6 +11,8 @@
__all__ = ["DBLoss", "PSEDiceLoss", "EASTLoss", "FCELoss"]
_logger = logging.getLogger(__name__)

OFFLINE_MODE = os.getenv("OFFLINE_MODE", None)


class DBLoss(nn.LossBase):
"""
Expand Down Expand Up @@ -165,7 +168,13 @@ def construct(self, pred: Tensor, gt: Tensor, mask: Tensor) -> Tensor:
neg_loss = (loss * negative).view(loss.shape[0], -1)

neg_vals, _ = ops.sort(neg_loss)
neg_index = ops.stack((mnp.arange(loss.shape[0]), neg_vals.shape[1] - neg_count), axis=1)

if OFFLINE_MODE is None:
neg_index = ops.stack((mnp.arange(loss.shape[0]), neg_vals.shape[1] - neg_count), axis=1)
else:
neg_index = ops.stack(
(ops.arange(loss.shape[0], dtype=neg_count.dtype), neg_vals.shape[1] - neg_count), axis=1
)
min_neg_score = ops.expand_dims(ops.gather_nd(neg_vals, neg_index), axis=1)

neg_loss_mask = (neg_loss >= min_neg_score).astype(ms.float32) # filter values less than top k
Expand Down
15 changes: 13 additions & 2 deletions mindocr/losses/rec_loss.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

import numpy as np

import mindspore as ms
Expand All @@ -6,6 +8,8 @@

__all__ = ["CTCLoss", "AttentionLoss", "VisionLANLoss"]

OFFLINE_MODE = os.getenv("OFFLINE_MODE", None)


class CTCLoss(LossBase):
"""
Expand Down Expand Up @@ -147,14 +151,21 @@ class AttentionLoss(LossBase):
def __init__(self, reduction: str = "mean", ignore_index: int = 0) -> None:
super().__init__()
# ignore <GO> symbol, assume it is placed at 0th index
self.criterion = nn.CrossEntropyLoss(reduction=reduction, ignore_index=ignore_index)
if OFFLINE_MODE is None:
self.criterion = nn.CrossEntropyLoss(reduction=reduction, ignore_index=ignore_index)
else:
self.reduction = reduction
self.ignore_index = ignore_index

def construct(self, logits: Tensor, labels: Tensor) -> Tensor:
labels = labels[:, 1:] # without <GO> symbol
num_classes = logits.shape[-1]
logits = ops.reshape(logits, (-1, num_classes))
labels = ops.reshape(labels, (-1,))
return self.criterion(logits, labels)
if OFFLINE_MODE is None:
return self.criterion(logits, labels)
else:
return ops.cross_entropy(logits, labels, reduction=self.reduction, ignore_index=self.ignore_index)


class SARLoss(LossBase):
Expand Down
34 changes: 24 additions & 10 deletions mindocr/models/necks/fpn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import List, Tuple

from mindspore import Tensor, nn, ops
Expand All @@ -7,14 +8,20 @@
from ..utils.attention_cells import SEModule
from .asf import AdaptiveScaleFusion

OFFLINE_MODE = os.getenv("OFFLINE_MODE", None)

def _resize_nn(x: Tensor, scale: int = 0, shape: Tuple[int] = None):
if scale == 1 or shape == x.shape[2:]:
return x

if scale:
shape = (x.shape[2] * scale, x.shape[3] * scale)
return ops.ResizeNearestNeighbor(shape)(x)
if OFFLINE_MODE is None:
def _resize_nn(x: Tensor, scale: int = 0, shape: Tuple[int] = None):
if scale == 1 or shape == x.shape[2:]:
return x

if scale:
shape = (x.shape[2] * scale, x.shape[3] * scale)
return ops.ResizeNearestNeighbor(shape)(x)
else:
def _resize_nn(x: Tensor, shape: Tensor):
return ops.ResizeNearestNeighborV2()(x, shape)


class FPN(nn.Cell):
Expand Down Expand Up @@ -64,11 +71,18 @@ def construct(self, features: List[Tensor]) -> Tensor:
for i, uc_op in enumerate(self.unify_channels):
features[i] = uc_op(features[i])

for i in range(2, -1, -1):
features[i] += _resize_nn(features[i + 1], shape=features[i].shape[2:])
if OFFLINE_MODE is None:
for i in range(2, -1, -1):
features[i] += _resize_nn(features[i + 1], shape=features[i].shape[2:])

for i, out in enumerate(self.out):
features[i] = _resize_nn(out(features[i]), shape=features[0].shape[2:])
else:
for i in range(2, -1, -1):
features[i] += _resize_nn(features[i + 1], shape=ops.dyn_shape(features[i])[2:])

for i, out in enumerate(self.out):
features[i] = _resize_nn(out(features[i]), shape=features[0].shape[2:])
for i, out in enumerate(self.out):
features[i] = _resize_nn(out(features[i]), shape=ops.dyn_shape(features[0])[2:])

return self.fuse(features[::-1]) # matching the reverse order of the original work

Expand Down
14 changes: 12 additions & 2 deletions mindocr/models/transforms/tps_spatial_transformer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
import os
from typing import Optional, Tuple

import numpy as np
Expand All @@ -8,6 +9,8 @@
import mindspore.ops as ops
from mindspore import Tensor

OFFLINE_MODE = os.getenv("OFFLINE_MODE", None)


def grid_sample(input: Tensor, grid: Tensor, canvas: Optional[Tensor] = None) -> Tensor:
output = ops.grid_sample(input, grid)
Expand Down Expand Up @@ -111,15 +114,22 @@ def __init__(
self.target_coordinate_repr = Tensor(target_coordinate_repr, dtype=ms.float32)
self.target_control_points = Tensor(target_control_points, dtype=ms.float32)

if OFFLINE_MODE is not None:
self.matmul = ops.BatchMatMul()

def construct(
self, input: Tensor, source_control_points: Tensor
) -> Tuple[Tensor, Tensor]:
batch_size = ops.shape(source_control_points)[0]

padding_matrix = ops.tile(self.padding_matrix, (batch_size, 1, 1))
Y = ops.concat([source_control_points, padding_matrix], axis=1)
mapping_matrix = ops.matmul(self.inverse_kernel, Y)
source_coordinate = ops.matmul(self.target_coordinate_repr, mapping_matrix)
if OFFLINE_MODE is None:
mapping_matrix = ops.matmul(self.inverse_kernel, Y)
source_coordinate = ops.matmul(self.target_coordinate_repr, mapping_matrix)
else:
mapping_matrix = self.matmul(self.inverse_kernel[None, ...], Y)
source_coordinate = self.matmul(self.target_coordinate_repr[None, ...], mapping_matrix)
grid = ops.reshape(
source_coordinate,
(-1, self.target_height, self.target_width, 2),
Expand Down
14 changes: 11 additions & 3 deletions mindocr/models/utils/attention_cells.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Optional, Tuple

import numpy as np
Expand All @@ -9,6 +10,8 @@

__all__ = ["MultiHeadAttention", "PositionwiseFeedForward", "PositionalEncoding", "SEModule"]

OFFLINE_MODE = os.getenv("OFFLINE_MODE", None)


class MultiHeadAttention(nn.Cell):
def __init__(
Expand Down Expand Up @@ -108,9 +111,14 @@ def __init__(
self.pe = Tensor(pe, dtype=ms.float32)

def construct(self, input_tensor: Tensor) -> Tensor:
input_tensor = (
input_tensor + self.pe[:, : input_tensor.shape[1]]
) # pe 1 5000 512
if OFFLINE_MODE is None:
input_tensor = (
input_tensor + self.pe[:, : input_tensor.shape[1]]
) # pe 1 5000 512
else:
input_tensor = (
input_tensor + self.pe[:, : ops.dyn_shape(input_tensor)[1]]
) # pe 1 5000 512
return self.dropout(input_tensor)


Expand Down