Skip to content

Commit

Permalink
Merge pull request #107 from breezedeus/dev
Browse files Browse the repository at this point in the history
Bugfixes and New Models: mfr-plus
  • Loading branch information
breezedeus authored May 19, 2024
2 parents cad5eb4 + 0499760 commit dce6dac
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 27 deletions.
15 changes: 15 additions & 0 deletions docs/RELEASE.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
# Release Notes

## Update 2024.05.19:**V1.1.0.3** Released

Major changes:

* A new paid model, `mfr-plus`, has been added, which offers better recognition for multi-line formulas.
* When recognizing only English, CnOCR does not output Chinese.
* Bugs have been fixed.

主要变更:

* 加入新的付费模型:`mfr-plus`,对多行公式的识别效果更好。
* 在只识别英文时,CnOCR 不输出中文。
* 修复 bugs。


## Update 2024.05.10:**V1.1.0.2** Released

Major changes:
Expand Down
2 changes: 1 addition & 1 deletion pix2text/__version__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
# [Pix2Text](https://github.com/breezedeus/pix2text): an Open-Source Alternative to Mathpix.
# Copyright (C) 2022-2024, [Breezedeus](https://www.breezedeus.com).

__version__ = '1.1.0.2'
__version__ = '1.1.0.3'
9 changes: 9 additions & 0 deletions pix2text/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,15 @@ class AvailableModels(object):
'filename': 'p2t-mfr-pro-pytorch.zip', # download the file from CN OSS
'local_model_id': 'mfr-pro-pytorch',
},
('mfr-plus', 'onnx'): {
'filename': 'p2t-mfr-plus-onnx.zip', # download the file from CN OSS
'hf_model_id': 'breezedeus/pix2text-mfr-plus',
'local_model_id': 'mfr-plus-onnx',
},
('mfr-plus', 'pytorch'): {
'filename': 'p2t-mfr-plus-pytorch.zip', # download the file from CN OSS
'local_model_id': 'mfr-plus-pytorch',
},
}
)

Expand Down
9 changes: 6 additions & 3 deletions pix2text/doc_xl_layout/doc_xl_layout_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,9 +355,12 @@ def _format_outputs(self, img0, out, table_as_image: bool):
for col_number, col_info in column_meta.items():
overlap_val = x_overlap(_box_info, col_info, key='position')
overlap_vals.append([col_number, overlap_val])
overlap_vals.sort(key=lambda x: (x[1], x[0]), reverse=True)
match_col_number = overlap_vals[0][0]
_box_info['col_number'] = match_col_number
if overlap_vals:
overlap_vals.sort(key=lambda x: (x[1], x[0]), reverse=True)
match_col_number = overlap_vals[0][0]
_box_info['col_number'] = match_col_number
else:
_box_info['col_number'] = 0

return final_out, column_meta

Expand Down
49 changes: 31 additions & 18 deletions pix2text/latex_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@
)

from PIL import Image
from cnstd.utils import get_model_file
from cnocr.utils import get_default_ort_providers
from transformers.generation import (
GenerateEncoderDecoderOutput,
GenerateBeamEncoderDecoderOutput,
)

from .consts import MODEL_VERSION, AVAILABLE_MODELS
from .utils import data_dir, select_device, prepare_imgs
Expand Down Expand Up @@ -196,19 +199,7 @@ def _one_batch(self, img_list, rec_config, **kwargs):
output_scores=True,
**rec_config,
)
logits = torch.stack(outs.scores, dim=1)
scores = torch.softmax(logits, dim=-1).max(dim=2).values

mean_probs = []
for idx, example in enumerate(scores):
cur_length = int(
(outs.sequences[idx] != self.processor.tokenizer.pad_token_id).sum()
)
assert cur_length > 1
# 获得几何平均值。注意:example中的第一个元素对应sequence中的第二个元素
mean_probs.append(
float((example[: cur_length - 1] + 1e-8).log().mean().exp())
)
mean_probs = self._cal_scores(outs)

generated_text = self.processor.batch_decode(
outs.sequences, skip_special_tokens=True
Expand All @@ -220,6 +211,28 @@ def _one_batch(self, img_list, rec_config, **kwargs):
final_out.append({'text': text, 'score': prob})
return final_out

def _cal_scores(self, outs):
    """
    Compute one confidence score per generated sequence.

    Args:
        outs: the object returned by the generation call. Supported types are
            `GenerateBeamEncoderDecoderOutput` and `GenerateEncoderDecoderOutput`.

    Returns:
        list of float: one score per sequence.

    Raises:
        TypeError: if `outs` is neither of the supported output types.
    """
    if isinstance(outs, GenerateBeamEncoderDecoderOutput):
        # The per-sequence scores are exponentiated as-is.
        # NOTE(review): presumably these are log-domain beam scores -- confirm.
        return outs.sequences_scores.exp().tolist()

    if not isinstance(outs, GenerateEncoderDecoderOutput):
        # TypeError is still caught by callers that handle the previous bare `Exception`.
        raise TypeError(f'unprocessed output type: {type(outs)}')

    # Stack the per-step score tensors along a new dim 1, then keep the highest
    # softmax probability at every step.
    logits = torch.stack(outs.scores, dim=1)
    scores = torch.softmax(logits, dim=-1).max(dim=2).values

    # Loop-invariant lookup, hoisted out of the per-example loop.
    pad_token_id = self.processor.tokenizer.pad_token_id

    mean_probs = []
    for idx, example in enumerate(scores):
        # Number of non-padding tokens in this sequence.
        cur_length = int((outs.sequences[idx] != pad_token_id).sum())
        assert cur_length > 1
        # Geometric mean of the step probabilities. Note: the first element of
        # `example` corresponds to the second element of the sequence, hence
        # the `cur_length - 1` bound.
        mean_probs.append(
            float((example[: cur_length - 1] + 1e-8).log().mean().exp())
        )

    return mean_probs

def post_process(self, text):
text = remove_redundant_script(text)
text = remove_trailing_whitespace(text)
Expand Down Expand Up @@ -247,10 +260,10 @@ def remove_redundant_script(text):

def replace_illegal_symbols(text):
illegal_to_legals = [
(r'\\\.', r'\\ .'), # \. -> \ .
(r'\\=', r'\\ ='), # \= -> \ =
(r'\\-', r'\\ -'), # \- -> \ -
(r'\\~', r'\\ ~'), # \~ -> \ ~
(r'\\\.', r'\\ .'), # \. -> \ .
(r'\\=', r'\\ ='), # \= -> \ =
(r'\\-', r'\\ -'), # \- -> \ -
(r'\\~', r'\\ ~'), # \~ -> \ ~
]
for illegal, legal in illegal_to_legals:
text = re.sub(illegal, legal, text)
Expand Down
3 changes: 3 additions & 0 deletions pix2text/ocr_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# coding: utf-8
# [Pix2Text](https://github.com/breezedeus/pix2text): an Open-Source Alternative to Mathpix.
# Copyright (C) 2022-2024, [Breezedeus](https://www.breezedeus.com).
import string
from typing import Sequence, List, Optional

import numpy as np
Expand Down Expand Up @@ -177,6 +178,8 @@ def prepare_ocr_engine(languages: Sequence[str], ocr_engine_config):
if len(set(languages).difference({'en', 'ch_sim'})) == 0:
from cnocr import CnOcr

if 'ch_sim' not in languages and 'cand_alphabet' not in ocr_engine_config: # only recognize english characters
ocr_engine_config['cand_alphabet'] = string.printable
ocr_engine = CnOcr(**ocr_engine_config)
engine_wrapper = CnOCREngine(languages, ocr_engine)
else:
Expand Down
18 changes: 13 additions & 5 deletions pix2text/page_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,10 @@ def __repr__(self) -> str:
return f"Page(id={self.id}, number={self.number}, elements={self.elements})"

def to_markdown(
self, out_dir: Union[str, Path], root_url: Optional[str]=None, markdown_fn: Optional[str] = 'output.md'
self,
out_dir: Union[str, Path],
root_url: Optional[str] = None,
markdown_fn: Optional[str] = 'output.md',
) -> str:
out_dir = Path(out_dir)
out_dir.mkdir(exist_ok=True, parents=True)
Expand Down Expand Up @@ -224,7 +227,9 @@ def to_markdown(
f.write(md_out)
return md_out

def _ele_to_markdown(self, element: Element, root_url: Optional[str], out_dir: Union[str, Path]):
def _ele_to_markdown(
self, element: Element, root_url: Optional[str], out_dir: Union[str, Path]
):
type = element.type
text = element.text
if type in (ElementType.TEXT, ElementType.TABLE):
Expand Down Expand Up @@ -283,17 +288,20 @@ def __repr__(self) -> str:
return f"Document(id={self.id}, number={self.number}, pages={self.pages})"

def to_markdown(
self, out_dir: Union[str, Path], markdown_fn: Optional[str] = 'output.md'
self,
out_dir: Union[str, Path],
root_url: Optional[str] = None,
markdown_fn: Optional[str] = 'output.md',
) -> str:
out_dir = Path(out_dir)
out_dir.mkdir(exist_ok=True, parents=True)
self.pages.sort(key=lambda page: page.number)
if not self.pages:
return ''
md_out = self.pages[0].to_markdown(out_dir, markdown_fn=None)
md_out = self.pages[0].to_markdown(out_dir, root_url=root_url, markdown_fn=None)
for idx, page in enumerate(self.pages[1:]):
prev_page = self.pages[idx]
cur_txt = page.to_markdown(out_dir, markdown_fn=None)
cur_txt = page.to_markdown(out_dir, root_url=root_url, markdown_fn=None)
if (
md_out
and prev_page.elements
Expand Down

0 comments on commit dce6dac

Please sign in to comment.