-
Notifications
You must be signed in to change notification settings - Fork 0
/
utilities.py
50 lines (35 loc) · 1.14 KB
/
utilities.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import numpy as np
import pickle
import torch
import torch.nn as nn
from transformer import Transformer_Model
from torchtext.data import get_tokenizer
from torchtext.vocab import GloVe
# Load the TorchScript classifier. A scripted module carries its own code, so
# no Python-side Transformer_Model instance is needed to load it (the original
# code constructed one and immediately discarded it).
# map_location="cpu" keeps the weights on the CPU regardless of the device the
# checkpoint was saved from, matching the CPU tensors that preprocess() builds.
model = torch.jit.load('models/text_classification-0.1.0.pth', map_location='cpu')
model.eval()
# Pretrained GloVe word vectors (6B corpus, 50 dimensions); torchtext may
# download and cache them on first use.
global_vectors = GloVe(name='6B', dim=50)
tokenizer = get_tokenizer("basic_english")
# Embedding of the "." token, used by preprocess() to right-pad inputs that are
# shorter than the fixed sequence length.
padding = global_vectors.get_vecs_by_tokens(["."], lower_case_backup=True)
def preprocess(text, max_len=100):
    """Embed *text* as a fixed-length GloVe sequence for the classifier.

    The text is split with the module-level ``tokenizer`` and each token is
    looked up in ``global_vectors`` (falling back to the lower-cased token).
    The sequence is truncated to ``max_len`` tokens, or right-padded with the
    module-level ``padding`` vector (the embedding of ".") when shorter.

    Args:
        text: Raw input string. Empty or whitespace-only text produces a
            sequence made entirely of padding.
        max_len: Number of token positions in the output. Defaults to 100,
            the length the pipeline was written for.

    Returns:
        A float32 tensor of shape ``(1, max_len, dim)``, where ``dim`` is the
        embedding width of the loaded GloVe vectors (50 in this module).
    """
    tokens = tokenizer(text)
    dim = padding.shape[-1]
    if tokens:
        vecs = global_vectors.get_vecs_by_tokens(tokens, lower_case_backup=True)
        vecs = vecs[:max_len]
    else:
        # get_vecs_by_tokens stacks the per-token vectors and cannot handle an
        # empty token list, so start from an empty (0, dim) sequence instead.
        vecs = padding[:0]
    n_pad = max_len - vecs.shape[0]
    if n_pad > 0:
        # Repeat the single padding row to fill the remaining positions.
        fill = padding.reshape(1, dim).expand(n_pad, dim)
        vecs = torch.cat([vecs, fill], dim=0)
    # Add the leading batch dimension expected by the model.
    return vecs.float().reshape(1, max_len, dim)
def predict(embeddings, threshold=0.5):
    """Classify pre-computed embeddings as human- or ChatGPT-written text.

    Args:
        embeddings: Model input, as produced by ``preprocess`` (a float tensor
            of shape ``(1, 100, 50)`` in the default pipeline).
        threshold: Cut-off applied to the mean of the model output. Defaults
            to 0.5.

    Returns:
        "Human written text" if the mean model output is >= ``threshold``,
        otherwise "ChatGPT generated text".
    """
    # Pure inference: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        # Averaging over every output element reduces the output to one score.
        score = torch.mean(model(embeddings)).item()
    if score >= threshold:
        return "Human written text"
    return "ChatGPT generated text"
def predict_pipeline(text):
    """Classify a raw string end to end.

    The text is embedded with ``preprocess`` and the resulting tensor is
    passed to ``predict``.

    Args:
        text: Raw input string.

    Returns:
        The label string returned by ``predict``.
    """
    return predict(preprocess(text))