99from fairseq .models .transformer import TransformerModel
1010from sacremoses import MosesTokenizer
1111
12- _en_tokenizer = MosesTokenizer ("en" )
1312
13+ _EN_TH_MODEL_NAME = "scb_1m_en-th_moses"
1414# SCB_1M-MT_OPUS+TBASE_en-th_moses-spm_130000-16000_v1.0.tar.gz
15- _EN_TH_FILE_NAME = (
16- "SCB_1M-MT_OPUS+TBASE_en-th_moses-spm_130000-16000_v1.0"
17- )
15+ _EN_TH_FILE_NAME = "SCB_1M-MT_OPUS+TBASE_en-th_moses-spm_130000-16000_v1.0"
16+
17+ _TH_EN_MODEL_NAME = "scb_1m_th-en_spm"
1818# SCB_1M-MT_OPUS+TBASE_th-en_spm-spm_32000-joined_v1.0.tar.gz
1919_TH_EN_FILE_NAME = "SCB_1M-MT_OPUS+TBASE_th-en_spm-spm_32000-joined_v1.0"
2020
2121
22+ def _get_translate_path (model : str , * path : str ) -> str :
23+ return os .path .join (get_full_data_path (model ), * path )
24+
25+
2226def _download_install (name : str ) -> None :
2327 if get_corpus_path (name ) is None :
2428 download (name , force = True , version = "1.0" )
@@ -33,81 +37,79 @@ def _download_install(name: str) -> None:
3337
3438def download_model_all () -> None :
3539 """
36- Download Model
40+ Download all translation models in advanced
3741 """
38- _download_install ("scb_1m_th-en_spm" )
39- _download_install ("scb_1m_en-th_moses" )
42+ _download_install (_EN_TH_MODEL_NAME )
43+ _download_install (_TH_EN_MODEL_NAME )
4044
4145
42- def _get_translate_path (model : str , * path : str ) -> str :
43- return os .path .join (get_full_data_path (model ), * path )
44-
45-
46- class Translate :
46+ class EnThTranslator :
4747 def __init__ (self ):
48- self ._model = None
49- self ._model_name = None
50-
51- def _scb_en_th_model_init (self ):
52- if self ._model_name != "scb_1m_en-th_moses" :
53- self ._model_name = "scb_1m_en-th_moses"
54- _download_install (self ._model_name )
55- self ._model = TransformerModel .from_pretrained (
56- model_name_or_path = _get_translate_path (
57- self ._model_name , _EN_TH_FILE_NAME , "models" ,
58- ),
59- checkpoint_file = "checkpoint.pt" ,
60- data_name_or_path = _get_translate_path (
61- self ._model_name , _EN_TH_FILE_NAME , "vocab" ,
62- ),
63- )
64-
65- def _scb_en_th_translate (self , text : str ) -> str :
66- self ._scb_en_th_model_init ()
67- tokens = " " .join (_en_tokenizer .tokenize (text ))
48+ self ._tokenizer = MosesTokenizer ("en" )
49+
50+ self ._model_name = _EN_TH_MODEL_NAME
51+
52+ _download_install (self ._model_name )
53+ self ._model = TransformerModel .from_pretrained (
54+ model_name_or_path = _get_translate_path (
55+ self ._model_name ,
56+ _EN_TH_FILE_NAME ,
57+ "models" ,
58+ ),
59+ checkpoint_file = "checkpoint.pt" ,
60+ data_name_or_path = _get_translate_path (
61+ self ._model_name ,
62+ _EN_TH_FILE_NAME ,
63+ "vocab" ,
64+ ),
65+ )
66+
67+ def translate (self , text : str ) -> str :
68+ """
69+ Translate text from English to Thai
70+
71+ :param str text: input text in source language
72+ :return: translated text in target language
73+ :rtype: str
74+ """
75+ tokens = " " .join (self ._tokenizer .tokenize (text ))
6876 translated = self ._model .translate (tokens )
69- return translated .replace (' ' , '' ).replace ('▁' , ' ' ).strip ()
70-
71- def _scb_th_en_model_init (self ):
72- if self ._model_name != "scb_1m_th-en_spm" :
73- self ._model_name = "scb_1m_th-en_spm"
74- _download_install (self ._model_name )
75- self ._model = TransformerModel .from_pretrained (
76- model_name_or_path = _get_translate_path (
77- self ._model_name , _TH_EN_FILE_NAME , "models" ,
78- ),
79- checkpoint_file = "checkpoint.pt" ,
80- data_name_or_path = _get_translate_path (
81- self ._model_name , _TH_EN_FILE_NAME , "vocab" ,
82- ),
83- bpe = "sentencepiece" ,
84- sentencepiece_model = _get_translate_path (
85- self ._model_name , _TH_EN_FILE_NAME , "bpe" , "spm.th.model" ,
86- ),
87- )
88-
89- def _scb_th_en_translate (self , text : str ) -> str :
90- self ._scb_th_en_model_init ()
91- return self ._model .translate (text )
77+ return translated .replace (" " , "" ).replace ("▁" , " " ).strip ()
78+
9279
93- def translate (self , text : str , source : str , target : str ) -> str :
80+ class ThEnTranslator :
81+ def __init__ (self ):
82+ self ._model_name = _TH_EN_MODEL_NAME
83+
84+ _download_install (self ._model_name )
85+ self ._model = TransformerModel .from_pretrained (
86+ model_name_or_path = _get_translate_path (
87+ self ._model_name ,
88+ _TH_EN_FILE_NAME ,
89+ "models" ,
90+ ),
91+ checkpoint_file = "checkpoint.pt" ,
92+ data_name_or_path = _get_translate_path (
93+ self ._model_name ,
94+ _TH_EN_FILE_NAME ,
95+ "vocab" ,
96+ ),
97+ bpe = "sentencepiece" ,
98+ sentencepiece_model = _get_translate_path (
99+ self ._model_name ,
100+ _TH_EN_FILE_NAME ,
101+ "bpe" ,
102+ "spm.th.model" ,
103+ ),
104+ )
105+
106+ def translate (self , text : str ) -> str :
94107 """
95- Translate Language
108+ Translate text from Thai to English
96109
97110 :param str text: input text in source language
98- :param str source: source language ("en" or "th")
99- :param str target: target language ("en" or "th")
100-
101111 :return: translated text in target language
102112 :rtype: str
103113 """
104- translated = None
105- if source == "th" and target == "en" :
106- translated = self ._scb_th_en_translate (text )
107- elif source == "en" and target == "th" :
108- translated = self ._scb_en_th_translate (text )
109- else :
110- return ValueError (
111- "The combination of the arguments isn't allowed."
112- )
113- return translated
114+ self ._model_init ()
115+ return self ._model .translate (text )
0 commit comments