Skip to content

Commit ff74b39

Browse files
authored
Merge pull request #873 from pavaris-pm/dev
Add PhayaThaiBERT engine with new features [WIP] by @pavaris-pm and @MpolaarbearM
2 parents de4f206 + e7ef6ce commit ff74b39

File tree

9 files changed

+537
-7
lines changed

9 files changed

+537
-7
lines changed

pythainlp/augment/lm/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
# SPDX-FileCopyrightText: Copyright 2016-2023 PyThaiNLP Project
33
# SPDX-License-Identifier: Apache-2.0
44
"""
5-
LM
5+
Language Models
66
"""
77

88
__all__ = [
99
"FastTextAug",
1010
"Thai2transformersAug",
11+
"ThaiTextAugmenter",
1112
]
1213

1314
from pythainlp.augment.lm.fasttext import FastTextAug
15+
from pythainlp.augment.lm.phayathaibert import ThaiTextAugmenter
1416
from pythainlp.augment.lm.wangchanberta import Thai2transformersAug

pythainlp/augment/lm/phayathaibert.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# -*- coding: utf-8 -*-
2+
# SPDX-FileCopyrightText: Copyright 2016-2023 PyThaiNLP Project
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from typing import List
6+
import random
7+
import re
8+
9+
from pythainlp.phayathaibert.core import ThaiTextProcessor
10+
11+
12+
_MODEL_NAME = "clicknext/phayathaibert"
13+
14+
15+
class ThaiTextAugmenter:
    """
    Text augmentation with the PhayaThaiBERT masked language model.

    The augmenter appends a ``<mask>`` token to the input text and lets a
    ``fill-mask`` pipeline predict the following token, repeating the
    process to extend the text by several tokens.
    """

    def __init__(self) -> None:
        # Imported inside the constructor, so transformers is only loaded
        # when an augmenter is actually instantiated.
        from transformers import (
            AutoModelForMaskedLM,
            AutoTokenizer,
            pipeline,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)
        self.model_for_masked_lm = AutoModelForMaskedLM.from_pretrained(
            _MODEL_NAME
        )
        self.model = pipeline(
            "fill-mask",
            tokenizer=self.tokenizer,
            model=self.model_for_masked_lm,
        )
        self.processor = ThaiTextProcessor()

    def generate(
        self,
        sample_text: str,
        word_rank: int,
        max_length: int = 3,
        sample: bool = False,
    ) -> str:
        """
        Extend a masked text by repeatedly filling in its ``<mask>`` token.

        On every step the text is preprocessed, passed to the fill-mask
        pipeline, one candidate is chosen, and a new ``<mask>`` is appended
        to the chosen sequence for the next step.

        :param str sample_text: text containing a ``<mask>`` token
        :param int word_rank: index of the candidate to pick from the list
            returned by the pipeline (presumably ordered best-first --
            confirm against the pipeline documentation); ignored when
            ``sample`` is true
        :param int max_length: number of fill-mask steps to perform
        :param bool sample: pick a random candidate on every step instead
            of the one at ``word_rank``

        :return: the extended text with every ``<mask>`` removed; an empty
            string when ``max_length`` is not positive
        :rtype: str
        :raises IndexError: if ``word_rank`` is outside the candidate list
        """
        current_text = sample_text
        # Stays empty when the loop body never runs (max_length <= 0).
        final_text = ""

        for _ in range(max_length):
            model_input = self.processor.preprocess(current_text)
            candidates = self.model(model_input)
            if sample:
                # Sample within the candidates actually returned rather
                # than assuming there are always exactly five of them.
                rank = random.randrange(len(candidates))
            else:
                rank = word_rank
            current_text = candidates[rank]["sequence"] + "<mask>"
            final_text = current_text

        return final_text.replace("<mask>", "")

    def augment(
        self,
        text: str,
        num_augs: int = 3,
        sample: bool = False,
    ) -> List[str]:
        """
        Text augmentation from PhayaThaiBERT

        :param str text: Thai text
        :param int num_augs: an amount of augmentation text needed as an output
        :param bool sample: whether to sample the text as an output or not,
            true if more word diversity is needed

        :return: list of text augment
        :rtype: List[str]
        :raises ValueError: if ``num_augs`` is greater than 5

        :Example:
        ::

            from pythainlp.augment.lm import ThaiTextAugmenter

            aug = ThaiTextAugmenter()
            aug.augment("ช้างมีทั้งหมด 50 ตัว บน", num_augs=5)

            # output = ['ช้างมีทั้งหมด 50 ตัว บนโลกใบนี้ครับ.',
                'ช้างมีทั้งหมด 50 ตัว บนพื้นดินครับ...',
                'ช้างมีทั้งหมด 50 ตัว บนท้องฟ้าครับ...',
                'ช้างมีทั้งหมด 50 ตัว บนดวงจันทร์.‼',
                'ช้างมีทั้งหมด 50 ตัว บนเขาค่ะ😁']
        """
        MAX_NUM_AUGS = 5

        # Reject oversized requests up front, before any model call.
        if num_augs > MAX_NUM_AUGS:
            raise ValueError(
                f"augmentation of more than {num_augs} is exceeded the default limit: {MAX_NUM_AUGS}"
            )

        # The model needs a mask to fill; add one at the end if the caller
        # did not place one explicitly.
        if "<mask>" not in text:
            text = text + "<mask>"

        augment_list = []
        for rank in range(num_augs):
            gen_text = self.generate(text, rank, sample=sample)
            # The preprocessed text uses "<_>" as a placeholder that is
            # turned back into a plain space for the final output.
            processed_text = self.processor.preprocess(gen_text).replace(
                "<_>", " "
            )
            augment_list.append(processed_text)

        return augment_list

pythainlp/phayathaibert/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# -*- coding: utf-8 -*-
2+
# SPDX-FileCopyrightText: Copyright 2016-2023 PyThaiNLP Project
3+
# SPDX-License-Identifier: Apache-2.0
4+
__all__ = [
5+
"NamedEntityTagger",
6+
"PartOfSpeechTagger",
7+
"ThaiTextAugmenter",
8+
"ThaiTextProcessor",
9+
"segment",
10+
]
11+
12+
from pythainlp.phayathaibert.core import (
13+
NamedEntityTagger,
14+
PartOfSpeechTagger,
15+
ThaiTextAugmenter,
16+
ThaiTextProcessor,
17+
segment,
18+
)

0 commit comments

Comments
 (0)