Skip to content

Commit

Permalink
refactor: Avoids instantiating orientation predictor when unnecessary (
Browse files Browse the repository at this point in the history
…mindee#809)

* refactor: Avoids instantiating crop orientation predictor when unnecessary

* test: Updated unittests

* refactor: Refactored predictor args
  • Loading branch information
fg-mindee authored Jan 18, 2022
1 parent f0eae11 commit 4da0557
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 33 deletions.
27 changes: 20 additions & 7 deletions doctr/models/predictor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

from typing import List, Tuple
from typing import Any, List, Optional, Tuple

import numpy as np

from doctr.models.builder import DocumentBuilder

from .._utils import extract_crops, extract_rcrops, rectify_crops, rectify_loc_preds
from ..classification import crop_orientation_predictor
from ..classification.predictor import CropOrientationPredictor

__all__ = ['_OCRPredictor']

Expand All @@ -19,14 +20,26 @@ class _OCRPredictor:
"""Implements an object able to localize and identify text elements in a set of documents
Args:
det_predictor: detection module
reco_predictor: recognition module
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
without rotated textual elements.
straighten_pages: if True, estimates the page general orientation based on the median line orientation.
Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped
accordingly. Doing so will improve performances for documents with page-uniform rotations.
kwargs: keyword args of `DocumentBuilder`
"""

doc_builder: DocumentBuilder
crop_orientation_predictor: Optional[CropOrientationPredictor]

def __init__(self) -> None:
self.crop_orientation_predictor = crop_orientation_predictor(pretrained=True)
def __init__(
self,
assume_straight_pages: bool = True,
straighten_pages: bool = False,
**kwargs: Any,
) -> None:
self.assume_straight_pages = assume_straight_pages
self.straighten_pages = straighten_pages
self.crop_orientation_predictor = None if assume_straight_pages else crop_orientation_predictor(pretrained=True)
self.doc_builder = DocumentBuilder(**kwargs)

@staticmethod
def _generate_crops(
Expand Down Expand Up @@ -70,7 +83,7 @@ def _rectify_crops(
loc_preds: List[np.ndarray],
) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]:
# Work at a page level
orientations = [self.crop_orientation_predictor(page_crops) for page_crops in crops]
orientations = [self.crop_orientation_predictor(page_crops) for page_crops in crops] # type: ignore[misc]
rect_crops = [rectify_crops(page_crops, orientation) for page_crops, orientation in zip(crops, orientations)]
rect_loc_preds = [
rectify_loc_preds(page_loc_preds, orientation) if len(page_loc_preds) > 0 else page_loc_preds
Expand Down
15 changes: 4 additions & 11 deletions doctr/models/predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,10 @@

from doctr.io.elements import Document
from doctr.models._utils import estimate_orientation
from doctr.models.builder import DocumentBuilder
from doctr.models.detection.predictor import DetectionPredictor
from doctr.models.recognition.predictor import RecognitionPredictor
from doctr.utils.geometry import rotate_boxes, rotate_image

from ..classification import crop_orientation_predictor
from .base import _OCRPredictor

__all__ = ['OCRPredictor']
Expand All @@ -30,30 +28,25 @@ class OCRPredictor(nn.Module, _OCRPredictor):
reco_predictor: recognition module
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
without rotated textual elements.
export_as_straight_boxes: when assume_straight_pages is set to False, export final predictions
(potentially rotated) as straight bounding boxes.
straighten_pages: if True, estimates the page general orientation based on the median line orientation.
Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped
accordingly. Doing so will improve performances for documents with page-uniform rotations.
kwargs: keyword args of `DocumentBuilder`
"""

def __init__(
self,
det_predictor: DetectionPredictor,
reco_predictor: RecognitionPredictor,
assume_straight_pages: bool = True,
export_as_straight_boxes: bool = False,
straighten_pages: bool = False,
**kwargs: Any,
) -> None:

super().__init__()
nn.Module.__init__(self)
self.det_predictor = det_predictor.eval() # type: ignore[attr-defined]
self.reco_predictor = reco_predictor.eval() # type: ignore[attr-defined]
self.doc_builder = DocumentBuilder(export_as_straight_boxes=export_as_straight_boxes)
self.assume_straight_pages = assume_straight_pages
self.straighten_pages = straighten_pages
self.crop_orientation_predictor = crop_orientation_predictor(pretrained=True)
_OCRPredictor.__init__(self, assume_straight_pages, straighten_pages, **kwargs)

@torch.no_grad()
def forward(
Expand Down
15 changes: 4 additions & 11 deletions doctr/models/predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,11 @@

from doctr.io.elements import Document
from doctr.models._utils import estimate_orientation
from doctr.models.builder import DocumentBuilder
from doctr.models.detection.predictor import DetectionPredictor
from doctr.models.recognition.predictor import RecognitionPredictor
from doctr.utils.geometry import rotate_boxes, rotate_image
from doctr.utils.repr import NestedObject

from ..classification import crop_orientation_predictor
from .base import _OCRPredictor

__all__ = ['OCRPredictor']
Expand All @@ -30,30 +28,25 @@ class OCRPredictor(NestedObject, _OCRPredictor):
reco_predictor: recognition module
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
without rotated textual elements.
export_as_straight_boxes: when assume_straight_pages is set to False, export final predictions
(potentially rotated) as straight bounding boxes.
straighten_pages: if True, estimates the page general orientation based on the median line orientation.
Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped
accordingly. Doing so will improve performances for documents with page-uniform rotations.
kwargs: keyword args of `DocumentBuilder`
"""
_children_names = ['det_predictor', 'reco_predictor']
_children_names = ['det_predictor', 'reco_predictor', 'doc_builder']

def __init__(
self,
det_predictor: DetectionPredictor,
reco_predictor: RecognitionPredictor,
assume_straight_pages: bool = True,
export_as_straight_boxes: bool = False,
straighten_pages: bool = False,
**kwargs: Any,
) -> None:

super().__init__()
self.det_predictor = det_predictor
self.reco_predictor = reco_predictor
self.doc_builder = DocumentBuilder(export_as_straight_boxes=export_as_straight_boxes)
self.assume_straight_pages = assume_straight_pages
self.straighten_pages = straighten_pages
self.crop_orientation_predictor = crop_orientation_predictor(pretrained=True)
_OCRPredictor.__init__(self, assume_straight_pages, straighten_pages, **kwargs)

def __call__(
self,
Expand Down
9 changes: 5 additions & 4 deletions doctr/models/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def ocr_predictor(
reco_arch: str = 'crnn_vgg16_bn',
pretrained: bool = False,
assume_straight_pages: bool = True,
export_as_straight_boxes: bool = False,
preserve_aspect_ratio: bool = False,
export_as_straight_boxes: bool = False,
**kwargs: Any
) -> OCRPredictor:
"""End-to-end OCR architecture using one model for localization, and another for text recognition.
Expand All @@ -67,10 +67,11 @@ def ocr_predictor(
pretrained: If True, returns a model pre-trained on our OCR dataset
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
without rotated textual elements.
export_as_straight_boxes: when assume_straight_pages is set to False, export final predictions
(potentially rotated) as straight bounding boxes.
preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before
running the detection model on it.
export_as_straight_boxes: when assume_straight_pages is set to False, export final predictions
(potentially rotated) as straight bounding boxes.
kwargs: keyword args of `OCRPredictor`
Returns:
OCR predictor
Expand All @@ -81,7 +82,7 @@ def ocr_predictor(
reco_arch,
pretrained,
assume_straight_pages=assume_straight_pages,
export_as_straight_boxes=export_as_straight_boxes,
preserve_aspect_ratio=preserve_aspect_ratio,
export_as_straight_boxes=export_as_straight_boxes,
**kwargs,
)
6 changes: 6 additions & 0 deletions tests/pytorch/test_models_zoo_pt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pytest
from torch import nn

from doctr import models
from doctr.io import Document, DocumentFile
Expand Down Expand Up @@ -43,6 +44,11 @@ def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa
straighten_pages=straighten_pages,
)

if assume_straight_pages:
assert predictor.crop_orientation_predictor is None
else:
assert isinstance(predictor.crop_orientation_predictor, nn.Module)

out = predictor(doc)
assert isinstance(out, Document)
assert len(out.pages) == 2
Expand Down
6 changes: 6 additions & 0 deletions tests/tensorflow/test_models_zoo_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from doctr.models.predictor import OCRPredictor
from doctr.models.preprocessor import PreProcessor
from doctr.models.recognition.predictor import RecognitionPredictor
from doctr.utils.repr import NestedObject


@pytest.mark.parametrize(
Expand Down Expand Up @@ -45,6 +46,11 @@ def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa
straighten_pages=straighten_pages,
)

if assume_straight_pages:
assert predictor.crop_orientation_predictor is None
else:
assert isinstance(predictor.crop_orientation_predictor, NestedObject)

out = predictor(doc)
assert isinstance(out, Document)
assert len(out.pages) == 2
Expand Down

0 comments on commit 4da0557

Please sign in to comment.