Skip to content

phayathaibert, khavee, parse: Code clean up #889

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 34 additions & 23 deletions pythainlp/augment/lm/phayathaibert.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,31 @@


class ThaiTextAugmenter:
def __init__(self,) -> None:
from transformers import (AutoTokenizer,
AutoModelForMaskedLM,
pipeline,)
def __init__(self) -> None:
from transformers import (
AutoTokenizer,
AutoModelForMaskedLM,
pipeline,
)

self.tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)
self.model_for_masked_lm = AutoModelForMaskedLM.from_pretrained(_MODEL_NAME)
self.model = pipeline("fill-mask", tokenizer=self.tokenizer, model=self.model_for_masked_lm)
self.model_for_masked_lm = AutoModelForMaskedLM.from_pretrained(
_MODEL_NAME
)
self.model = pipeline(
"fill-mask",
tokenizer=self.tokenizer,
model=self.model_for_masked_lm,
)
self.processor = ThaiTextProcessor()

def generate(self,
sample_text: str,
word_rank: int,
max_length: int = 3,
sample: bool = False
) -> str:
def generate(
self,
sample_text: str,
word_rank: int,
max_length: int = 3,
sample: bool = False,
) -> str:
sample_txt = sample_text
final_text = ""

Expand All @@ -45,11 +55,9 @@ def generate(self,

return gen_txt

def augment(self,
text: str,
num_augs: int = 3,
sample: bool = False
) -> List[str]:
def augment(
self, text: str, num_augs: int = 3, sample: bool = False
) -> List[str]:
"""
Text augmentation from PhayaThaiBERT

Expand Down Expand Up @@ -84,11 +92,14 @@ def augment(self,
if num_augs <= MAX_NUM_AUGS:
for rank in range(num_augs):
gen_text = self.generate(text, rank, sample=sample)
processed_text = re.sub("<_>", " ", self.processor.preprocess(gen_text))
processed_text = re.sub(
"<_>", " ", self.processor.preprocess(gen_text)
)
augment_list.append(processed_text)
else:
raise ValueError(
f"augmentation of more than {num_augs} is exceeded \
the default limit: {MAX_NUM_AUGS}"
)

return augment_list

raise ValueError(
f"augmentation of more than {num_augs} is exceeded the default limit: {MAX_NUM_AUGS}"
)
return augment_list
8 changes: 5 additions & 3 deletions pythainlp/augment/lm/wangchanberta.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: Copyright 2016-2023 PyThaiNLP Project
# SPDX-License-Identifier: Apache-2.0

from typing import List

from transformers import (
CamembertTokenizer,
pipeline,
Expand Down Expand Up @@ -51,9 +53,9 @@ def generate(self, sentence: str, num_replace_tokens: int = 3):

def augment(self, sentence: str, num_replace_tokens: int = 3) -> List[str]:
"""
Text Augment from wangchanberta
Text augmentation from WangchanBERTa

:param str sentence: thai sentence
:param str sentence: Thai sentence
:param int num_replace_tokens: number replace tokens

:return: list of text augment
Expand All @@ -64,7 +66,7 @@ def augment(self, sentence: str, num_replace_tokens: int = 3) -> List[str]:

from pythainlp.augment.lm import Thai2transformersAug

aug=Thai2transformersAug()
aug = Thai2transformersAug()

aug.augment("ช้างมีทั้งหมด 50 ตัว บน")
# output: ['ช้างมีทั้งหมด 50 ตัว บนโลกใบนี้',
Expand Down
1 change: 1 addition & 0 deletions pythainlp/augment/wordnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from collections import OrderedDict
import itertools
from typing import List

from nltk.corpus import wordnet as wn
from pythainlp.corpus import wordnet
from pythainlp.tokenize import word_tokenize
Expand Down
1 change: 1 addition & 0 deletions pythainlp/khavee/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: Copyright 2016-2023 PyThaiNLP Project
# SPDX-License-Identifier: Apache-2.0

__all__ = ["KhaveeVerifier"]

from pythainlp.khavee.core import KhaveeVerifier
Loading