Skip to content

Commit c2e65bc

Browse files
authored
Merge pull request #508 from PyThaiNLP/translate-class
Change Translate to class
2 parents aac2277 + ce547a6 commit c2e65bc

File tree

5 files changed

+76
-86
lines changed

5 files changed

+76
-86
lines changed

docs/api/tools.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
pythainlp.tools
44
====================================
5-
The :class:`pythainlp.tools` is tool for pythainlp.
5+
The :class:`pythainlp.tools` contains miscellaneous functions for PyThaiNLP internal use.
66

77
Modules
88
-------
99

1010
.. autofunction:: get_full_data_path
1111
.. autofunction:: get_pythainlp_data_path
12-
.. autofunction:: get_pythainlp_path
12+
.. autofunction:: get_pythainlp_path

docs/api/translate.rst

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@
22

33
pythainlp.translate
44
===================
5-
The :class:`pythainlp.translate` for language translation.
5+
The :class:`pythainlp.translate` for machine translation.
66

77
Modules
88
-------
99

10-
.. autofunction:: translate
10+
.. autofunction:: download_model_all
11+
.. autoclass:: EnThTranslate
12+
:members: translate
13+
.. autoclass:: ThEnTranslate
14+
:members: translate

pythainlp/translate/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
Language translation.
44
"""
55

6-
__all__ = [
7-
"translate",
8-
"download_model_all"
9-
]
6+
__all__ = ["EnThTranslator", "ThEnTranslator", "download_model_all"]
107

11-
from pythainlp.translate.core import translate, download_model_all
8+
from pythainlp.translate.core import (
9+
EnThTranslator,
10+
ThEnTranslator,
11+
download_model_all,
12+
)

pythainlp/translate/core.py

Lines changed: 57 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,21 @@
99
from fairseq.models.transformer import TransformerModel
1010
from sacremoses import MosesTokenizer
1111

12-
_en_tokenizer = MosesTokenizer("en")
13-
14-
_model = None
15-
_model_name = None
1612

13+
_EN_TH_MODEL_NAME = "scb_1m_en-th_moses"
1714
# SCB_1M-MT_OPUS+TBASE_en-th_moses-spm_130000-16000_v1.0.tar.gz
18-
_EN_TH_FILE_NAME = (
19-
"SCB_1M-MT_OPUS+TBASE_en-th_moses-spm_130000-16000_v1.0"
20-
)
15+
_EN_TH_FILE_NAME = "SCB_1M-MT_OPUS+TBASE_en-th_moses-spm_130000-16000_v1.0"
16+
17+
_TH_EN_MODEL_NAME = "scb_1m_th-en_spm"
2118
# SCB_1M-MT_OPUS+TBASE_th-en_spm-spm_32000-joined_v1.0.tar.gz
2219
_TH_EN_FILE_NAME = "SCB_1M-MT_OPUS+TBASE_th-en_spm-spm_32000-joined_v1.0"
2320

2421

25-
def _download_install(name):
22+
def _get_translate_path(model: str, *path: str) -> str:
23+
return os.path.join(get_full_data_path(model), *path)
24+
25+
26+
def _download_install(name: str) -> None:
2627
if get_corpus_path(name) is None:
2728
download(name, force=True, version="1.0")
2829
tar = tarfile.open(get_corpus_path(name), "r:gz")
@@ -36,92 +37,78 @@ def _download_install(name):
3637

3738
def download_model_all() -> None:
3839
"""
39-
Download Model
40+
Download all translation models in advanced
4041
"""
41-
_download_install("scb_1m_th-en_spm")
42-
_download_install("scb_1m_en-th_moses")
42+
_download_install(_EN_TH_MODEL_NAME)
43+
_download_install(_TH_EN_MODEL_NAME)
4344

4445

45-
def _get_translate_path(model: str, *path: str) -> str:
46-
return os.path.join(get_full_data_path(model), *path)
47-
46+
class EnThTranslator:
47+
def __init__(self):
48+
self._tokenizer = MosesTokenizer("en")
4849

49-
def _scb_en_th_model_init():
50-
global _model, _model_name
50+
self._model_name = _EN_TH_MODEL_NAME
5151

52-
if _model_name != "scb_1m_en-th_moses":
53-
del _model
54-
_model_name = "scb_1m_en-th_moses"
55-
_download_install(_model_name)
56-
_model = TransformerModel.from_pretrained(
52+
_download_install(self._model_name)
53+
self._model = TransformerModel.from_pretrained(
5754
model_name_or_path=_get_translate_path(
58-
_model_name, _EN_TH_FILE_NAME, "models",
55+
self._model_name,
56+
_EN_TH_FILE_NAME,
57+
"models",
5958
),
6059
checkpoint_file="checkpoint.pt",
6160
data_name_or_path=_get_translate_path(
62-
_model_name, _EN_TH_FILE_NAME, "vocab",
61+
self._model_name,
62+
_EN_TH_FILE_NAME,
63+
"vocab",
6364
),
6465
)
6566

67+
def translate(self, text: str) -> str:
68+
"""
69+
Translate text from English to Thai
6670
67-
def _scb_en_th_translate(text: str) -> str:
68-
global _model, _model_name
69-
70-
_scb_en_th_model_init()
71-
72-
tokens = " ".join(_en_tokenizer.tokenize(text))
73-
translated = _model.translate(tokens)
74-
return translated.replace(' ', '').replace('▁', ' ').strip()
71+
:param str text: input text in source language
72+
:return: translated text in target language
73+
:rtype: str
74+
"""
75+
tokens = " ".join(self._tokenizer.tokenize(text))
76+
translated = self._model.translate(tokens)
77+
return translated.replace(" ", "").replace("▁", " ").strip()
7578

7679

77-
def _scb_th_en_model_init():
78-
global _model, _model_name
80+
class ThEnTranslator:
81+
def __init__(self):
82+
self._model_name = _TH_EN_MODEL_NAME
7983

80-
if _model_name != "scb_1m_th-en_spm":
81-
del _model
82-
_model_name = "scb_1m_th-en_spm"
83-
_download_install(_model_name)
84-
_model = TransformerModel.from_pretrained(
84+
_download_install(self._model_name)
85+
self._model = TransformerModel.from_pretrained(
8586
model_name_or_path=_get_translate_path(
86-
_model_name, _TH_EN_FILE_NAME, "models",
87+
self._model_name,
88+
_TH_EN_FILE_NAME,
89+
"models",
8790
),
8891
checkpoint_file="checkpoint.pt",
8992
data_name_or_path=_get_translate_path(
90-
_model_name, _TH_EN_FILE_NAME, "vocab",
93+
self._model_name,
94+
_TH_EN_FILE_NAME,
95+
"vocab",
9196
),
9297
bpe="sentencepiece",
9398
sentencepiece_model=_get_translate_path(
94-
_model_name, _TH_EN_FILE_NAME, "bpe", "spm.th.model",
99+
self._model_name,
100+
_TH_EN_FILE_NAME,
101+
"bpe",
102+
"spm.th.model",
95103
),
96104
)
97105

106+
def translate(self, text: str) -> str:
107+
"""
108+
Translate text from Thai to English
98109
99-
def _scb_th_en_translate(text: str) -> str:
100-
global _model, _model_name
101-
102-
_scb_th_en_model_init()
103-
104-
return _model.translate(text)
105-
106-
107-
def translate(text: str, source: str, target: str) -> str:
108-
"""
109-
Translate Language
110-
111-
:param str text: input text in source language
112-
:param str source: source language ("en" or "th")
113-
:param str target: target language ("en" or "th")
114-
115-
:return: translated text in target language
116-
:rtype: str
117-
"""
118-
translated = None
119-
120-
if source == "th" and target == "en":
121-
translated = _scb_th_en_translate(text)
122-
elif source == "en" and target == "th":
123-
translated = _scb_en_th_translate(text)
124-
else:
125-
return ValueError("The combination of the arguments isn't allowed.")
126-
127-
return translated
110+
:param str text: input text in source language
111+
:return: translated text in target language
112+
:rtype: str
113+
"""
114+
return self._model.translate(text)

tests/test_translate.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,20 @@
22

33
import unittest
44

5-
from pythainlp.translate import translate
5+
from pythainlp.translate import EnThTranslator, ThEnTranslator
66

77

88
class TestTranslatePackage(unittest.TestCase):
99
def test_translate(self):
10+
self.th_en_translator = ThEnTranslator()
1011
self.assertIsNotNone(
11-
translate(
12+
self.th_en_translator.translate(
1213
"แมวกินปลา",
13-
source="th",
14-
target="en"
1514
)
1615
)
16+
self.en_th_translator = EnThTranslator()
1717
self.assertIsNotNone(
18-
translate(
18+
self.en_th_translator.translate(
1919
"the cat eats fish.",
20-
source="en",
21-
target="th"
2220
)
2321
)

0 commit comments

Comments
 (0)