# -*- coding: utf-8 -*-
import random
from typing import List

# transformers
from transformers import (
    CamembertTokenizer,
    pipeline,
)

from thai2transformers.preprocess import process_transformers

model_name = "airesearch/wangchanberta-base-att-spm-uncased"


class Thai2transformersAug:
    def __init__(self):
        self.model_name = "airesearch/wangchanberta-base-att-spm-uncased"
        self.target_tokenizer = CamembertTokenizer
        # WangchanBERTa uses the CamemBERT (SentencePiece) tokenizer
        self.tokenizer = CamembertTokenizer.from_pretrained(
            self.model_name,
            revision='main',
        )
        self.tokenizer.additional_special_tokens = [
            '<s>NOTUSED',
            '</s>NOTUSED',
            '<_>',
        ]
        # Fill-mask pipeline used to predict replacements for masked tokens
        self.fill_mask = pipeline(
            task='fill-mask',
            tokenizer=self.tokenizer,
            model=self.model_name,
            revision='main',
        )

    def generate(self, sentence: str, num_replace_tokens: int = 3):
        self.sent2 = []
        self.input_text = process_transformers(sentence)
        # Tokenize and drop the bare SentencePiece word-boundary symbol
        sent = [i for i in self.tokenizer.tokenize(self.input_text) if i != '▁']
        if len(sent) < num_replace_tokens:
            num_replace_tokens = len(sent)
        masked_text = self.input_text
        for _ in range(num_replace_tokens):
            # Mask the first occurrence of a randomly chosen token
            replace_token = sent.pop(random.randrange(len(sent)))
            masked_text = masked_text.replace(
                replace_token, self.fill_mask.tokenizer.mask_token, 1
            )
            # Keep new predictions, stripping the <s>/</s> special tokens
            self.sent2 += [
                str(j['sequence']).replace('<s> ', '').replace('</s>', '')
                for j in self.fill_mask(masked_text + '<pad>')
                if j['sequence'] not in self.sent2
            ]
            masked_text = self.input_text
        return self.sent2
50+
    def augment(self, sentence: str, num_replace_tokens: int = 3) -> List[str]:
        """
        Text augmentation using WangchanBERTa.

        :param str sentence: Thai sentence
        :param int num_replace_tokens: number of tokens to replace

        :return: list of augmented sentences
        :rtype: List[str]
        """
        self.sent2 = []
        try:
            self.sent2 = self.generate(sentence, num_replace_tokens)
            if self.sent2 == []:
                # Retry once if the first pass produced no candidates
                self.sent2 = self.generate(sentence, num_replace_tokens)
            return self.sent2
        except Exception:
            # Return whatever was generated before the error (possibly empty)
            return self.sent2
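

# Minimal usage sketch, assuming the WangchanBERTa weights can be downloaded
# and thai2transformers is installed; the Thai sentence below is illustrative
# only, and actual outputs depend on the fill-mask predictions.
if __name__ == "__main__":
    aug = Thai2transformersAug()
    # Ask for up to 3 tokens to be replaced by model-predicted alternatives
    print(aug.augment("ฉันชอบไปเที่ยวที่เชียงใหม่", num_replace_tokens=3))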