Skip to content

Commit 78aa92e

Browse files
committed
Divide en->th and th->en translations into two separated classes
1 parent 0859ec8 commit 78aa92e

File tree

5 files changed

+89
-87
lines changed

5 files changed

+89
-87
lines changed

docs/api/tools.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
pythainlp.tools
44
====================================
5-
The :class:`pythainlp.tools` is tool for pythainlp.
5+
The :class:`pythainlp.tools` contains miscellaneous functions for PyThaiNLP internal use.
66

77
Modules
88
-------
99

1010
.. autofunction:: get_full_data_path
1111
.. autofunction:: get_pythainlp_data_path
12-
.. autofunction:: get_pythainlp_path
12+
.. autofunction:: get_pythainlp_path

docs/api/translate.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,7 @@ Modules
88
-------
99

1010
.. autofunction:: download_model_all
11-
.. autoclass:: Translate
12-
:members: translate
11+
.. autoclass:: EnThTranslator
12+
:members: translate
13+
.. autoclass:: ThEnTranslator
14+
:members: translate

pythainlp/translate/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
Language translation.
44
"""
55

6-
__all__ = [
7-
"Translate",
8-
"download_model_all"
9-
]
6+
__all__ = ["EnThTranslator", "ThEnTranslator", "download_model_all"]
107

11-
from pythainlp.translate.core import Translate, download_model_all
8+
from pythainlp.translate.core import (
9+
EnThTranslator,
10+
ThEnTranslator,
11+
download_model_all,
12+
)

pythainlp/translate/core.py

Lines changed: 72 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,20 @@
99
from fairseq.models.transformer import TransformerModel
1010
from sacremoses import MosesTokenizer
1111

12-
_en_tokenizer = MosesTokenizer("en")
1312

13+
_EN_TH_MODEL_NAME = "scb_1m_en-th_moses"
1414
# SCB_1M-MT_OPUS+TBASE_en-th_moses-spm_130000-16000_v1.0.tar.gz
15-
_EN_TH_FILE_NAME = (
16-
"SCB_1M-MT_OPUS+TBASE_en-th_moses-spm_130000-16000_v1.0"
17-
)
15+
_EN_TH_FILE_NAME = "SCB_1M-MT_OPUS+TBASE_en-th_moses-spm_130000-16000_v1.0"
16+
17+
_TH_EN_MODEL_NAME = "scb_1m_th-en_spm"
1818
# SCB_1M-MT_OPUS+TBASE_th-en_spm-spm_32000-joined_v1.0.tar.gz
1919
_TH_EN_FILE_NAME = "SCB_1M-MT_OPUS+TBASE_th-en_spm-spm_32000-joined_v1.0"
2020

2121

22+
def _get_translate_path(model: str, *path: str) -> str:
23+
return os.path.join(get_full_data_path(model), *path)
24+
25+
2226
def _download_install(name: str) -> None:
2327
if get_corpus_path(name) is None:
2428
download(name, force=True, version="1.0")
@@ -33,81 +37,79 @@ def _download_install(name: str) -> None:
3337

3438
def download_model_all() -> None:
3539
"""
36-
Download Model
40+
Download all translation models in advance
3741
"""
38-
_download_install("scb_1m_th-en_spm")
39-
_download_install("scb_1m_en-th_moses")
42+
_download_install(_EN_TH_MODEL_NAME)
43+
_download_install(_TH_EN_MODEL_NAME)
4044

4145

42-
def _get_translate_path(model: str, *path: str) -> str:
43-
return os.path.join(get_full_data_path(model), *path)
44-
45-
46-
class Translate:
46+
class EnThTranslator:
4747
def __init__(self):
48-
self._model = None
49-
self._model_name = None
50-
51-
def _scb_en_th_model_init(self):
52-
if self._model_name != "scb_1m_en-th_moses":
53-
self._model_name = "scb_1m_en-th_moses"
54-
_download_install(self._model_name)
55-
self._model = TransformerModel.from_pretrained(
56-
model_name_or_path=_get_translate_path(
57-
self._model_name, _EN_TH_FILE_NAME, "models",
58-
),
59-
checkpoint_file="checkpoint.pt",
60-
data_name_or_path=_get_translate_path(
61-
self._model_name, _EN_TH_FILE_NAME, "vocab",
62-
),
63-
)
64-
65-
def _scb_en_th_translate(self, text: str) -> str:
66-
self._scb_en_th_model_init()
67-
tokens = " ".join(_en_tokenizer.tokenize(text))
48+
self._tokenizer = MosesTokenizer("en")
49+
50+
self._model_name = _EN_TH_MODEL_NAME
51+
52+
_download_install(self._model_name)
53+
self._model = TransformerModel.from_pretrained(
54+
model_name_or_path=_get_translate_path(
55+
self._model_name,
56+
_EN_TH_FILE_NAME,
57+
"models",
58+
),
59+
checkpoint_file="checkpoint.pt",
60+
data_name_or_path=_get_translate_path(
61+
self._model_name,
62+
_EN_TH_FILE_NAME,
63+
"vocab",
64+
),
65+
)
66+
67+
def translate(self, text: str) -> str:
68+
"""
69+
Translate text from English to Thai
70+
71+
:param str text: input text in source language
72+
:return: translated text in target language
73+
:rtype: str
74+
"""
75+
tokens = " ".join(self._tokenizer.tokenize(text))
6876
translated = self._model.translate(tokens)
69-
return translated.replace(' ', '').replace('▁', ' ').strip()
70-
71-
def _scb_th_en_model_init(self):
72-
if self._model_name != "scb_1m_th-en_spm":
73-
self._model_name = "scb_1m_th-en_spm"
74-
_download_install(self._model_name)
75-
self._model = TransformerModel.from_pretrained(
76-
model_name_or_path=_get_translate_path(
77-
self._model_name, _TH_EN_FILE_NAME, "models",
78-
),
79-
checkpoint_file="checkpoint.pt",
80-
data_name_or_path=_get_translate_path(
81-
self._model_name, _TH_EN_FILE_NAME, "vocab",
82-
),
83-
bpe="sentencepiece",
84-
sentencepiece_model=_get_translate_path(
85-
self._model_name, _TH_EN_FILE_NAME, "bpe", "spm.th.model",
86-
),
87-
)
88-
89-
def _scb_th_en_translate(self, text: str) -> str:
90-
self._scb_th_en_model_init()
91-
return self._model.translate(text)
77+
return translated.replace(" ", "").replace("▁", " ").strip()
78+
9279

93-
def translate(self, text: str, source: str, target: str) -> str:
80+
class ThEnTranslator:
81+
def __init__(self):
82+
self._model_name = _TH_EN_MODEL_NAME
83+
84+
_download_install(self._model_name)
85+
self._model = TransformerModel.from_pretrained(
86+
model_name_or_path=_get_translate_path(
87+
self._model_name,
88+
_TH_EN_FILE_NAME,
89+
"models",
90+
),
91+
checkpoint_file="checkpoint.pt",
92+
data_name_or_path=_get_translate_path(
93+
self._model_name,
94+
_TH_EN_FILE_NAME,
95+
"vocab",
96+
),
97+
bpe="sentencepiece",
98+
sentencepiece_model=_get_translate_path(
99+
self._model_name,
100+
_TH_EN_FILE_NAME,
101+
"bpe",
102+
"spm.th.model",
103+
),
104+
)
105+
106+
def translate(self, text: str) -> str:
94107
"""
95-
Translate Language
108+
Translate text from Thai to English
96109
97110
:param str text: input text in source language
98-
:param str source: source language ("en" or "th")
99-
:param str target: target language ("en" or "th")
100-
101111
:return: translated text in target language
102112
:rtype: str
103113
"""
104-
translated = None
105-
if source == "th" and target == "en":
106-
translated = self._scb_th_en_translate(text)
107-
elif source == "en" and target == "th":
108-
translated = self._scb_en_th_translate(text)
109-
else:
110-
return ValueError(
111-
"The combination of the arguments isn't allowed."
112-
)
113-
return translated
114+
self._model_init()
115+
return self._model.translate(text)

tests/test_translate.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,20 @@
22

33
import unittest
44

5-
from pythainlp.translate import Translate
5+
from pythainlp.translate import EnThTranslator, ThEnTranslator
66

77

88
class TestTranslatePackage(unittest.TestCase):
99
def test_translate(self):
10-
self.translate = Translate()
10+
self.th_en_translator = ThEnTranslator()
1111
self.assertIsNotNone(
12-
self.translate.translate(
12+
self.th_en_translator.translate(
1313
"แมวกินปลา",
14-
source="th",
15-
target="en"
1614
)
1715
)
16+
self.en_th_translator = EnThTranslator()
1817
self.assertIsNotNone(
19-
self.translate.translate(
18+
self.en_th_translator.translate(
2019
"the cat eats fish.",
21-
source="en",
22-
target="th"
2320
)
2421
)

0 commit comments

Comments
 (0)