gen_embedding_sincse.py
import argparse
import json
import pickle

import numpy as np
import torch
from scipy.spatial.distance import cosine
from tqdm import trange
from transformers import AutoModel, AutoTokenizer
# Parse command-line arguments
def get_arguments():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--dataset",
        default="agnews",
        type=str,
        required=True,
        help="The dataset directory. Should contain the JSON file with the input texts.",
    )
    parser.add_argument(
        "--type",
        default="unlabeled",
        type=str,
        help="The data split to embed (name of the JSON file, e.g. 'unlabeled').",
    )
    parser.add_argument(
        "--batch_size",
        default=64,
        type=int,
        help="Batch size used when encoding texts.",
    )
    parser.add_argument(
        "--gpuid",
        default=0,
        type=int,
        help="Index of the CUDA device to run the model on.",
    )
    args = parser.parse_args()
    return args
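# Example invocation (paths are hypothetical; --dataset must point at a directory
# containing <type>.json, e.g. agnews/unlabeled.json):
#   python gen_embedding_sincse.py --dataset agnews --type unlabeled --batch_size 64 --gpuid 0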
args = get_arguments()

model_name = "princeton-nlp/unsup-simcse-roberta-base"

# Read the input texts and their ids, one JSON object per line
text = []
label = []
with open(f"{args.dataset}/{args.type}.json", "r") as f:
    for line in f:
        record = json.loads(line)
        text.append(record["text"])
        label.append(record["_id"])
# Print the number of texts/ids read
print(len(text), len(label))
# Load the tokenizer and model; the weights are downloaded automatically on first use
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model = model.to(f"cuda:{args.gpuid}")
# Encode in batches; num_iter is the ceiling of len(text) / batch_size, so no
# empty trailing batch is fed to the tokenizer when the length divides evenly
embedding = []
num_iter = (len(text) + args.batch_size - 1) // args.batch_size
for i in trange(num_iter):
    batch = text[i * args.batch_size : (i + 1) * args.batch_size]
    inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(f"cuda:{args.gpuid}")
    # Get the embeddings from the pooler output
    with torch.no_grad():
        embeddings = model(**inputs, output_hidden_states=True, return_dict=True).pooler_output
    embedding.append(embeddings.cpu().numpy())
embedding = np.concatenate(embedding, axis=0)
print(embedding.shape)
# Save the embeddings (the argument parser defines no --model flag, so the
# output name is derived from the dataset and split only)
with open(f"{args.dataset}/embedding_simcse_{args.type}.pkl", "wb") as handle:
    pickle.dump(embedding, handle, protocol=4)
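
# --- Illustrative downstream usage (a sketch, not part of the original script) ---
# Reload the saved embeddings and compute a cosine similarity between the first
# two texts; scipy's `cosine` (imported above) returns a distance, so the
# similarity is 1 - distance. The file name matches the dump above.
with open(f"{args.dataset}/embedding_simcse_{args.type}.pkl", "rb") as handle:
    reloaded = pickle.load(handle)
if reloaded.shape[0] >= 2:
    similarity = 1.0 - cosine(reloaded[0], reloaded[1])
    print(f"cosine similarity between text 0 and text 1: {similarity:.4f}")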