Commit 3edcc58 (1 parent: a66da67). Showing 7 changed files with 515 additions and 33 deletions.
@@ -5,4 +5,5 @@ data/*
 data.tar.gz
 .ipynb_checkpoints
 __pycache__
 outputs
+playground
@@ -0,0 +1,92 @@
# Source: https://github.com/boostcampaitech3/level2-mrc-level2-nlp-11

import json
import torch.nn.functional as F
from model import *
from tokenizer import *
import logging
import sys
from typing import Callable, Dict, List, NoReturn, Tuple
import torch
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer


def main():
    epoch = 6
    MODEL_NAME = "klue/bert-base"

    dataset = load_data("/opt/ml/input/data/train.csv")
    val_dataset = dataset[3952:]

    # Register the [Q]/[D] marker tokens used by tokenize_colbert; the embedding
    # matrix is resized to vocab_size + 2 to account for them.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.add_special_tokens({"additional_special_tokens": ["[Q]", "[D]"]})
    model = ColbertModel.from_pretrained(MODEL_NAME)
    model.resize_token_embeddings(tokenizer.vocab_size + 2)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    model.load_state_dict(
        torch.load(f"/opt/ml/input/code/colbert/best_model/colbert_epoch{epoch}.pth")
    )

    print("opening wiki passage...")
    with open("/opt/ml/input/data/wikipedia_documents.json", "r", encoding="utf-8") as f:
        wiki = json.load(f)
    context = list(dict.fromkeys([v["text"] for v in wiki.values()]))
    print("wiki loaded!!!")

    query = list(val_dataset["query"])
    ground_truth = list(val_dataset["context"])

    batched_p_embs = []
    with torch.no_grad():
        model.eval()

        # Tokenize and embed the validation queries
        q_seqs_val = tokenize_colbert(query, tokenizer, corpus="query").to("cuda")
        q_emb = model.query(**q_seqs_val).to("cpu")

        print(q_emb.size())

        print("Start passage embedding......")
        p_embs = []
        for step, p in enumerate(tqdm(context)):
            p = tokenize_colbert(p, tokenizer, corpus="doc").to("cuda")
            p_emb = model.doc(**p).to("cpu").numpy()
            p_embs.append(p_emb)
            # keep passage embeddings in chunks of 200 to bound memory
            if (step + 1) % 200 == 0:
                batched_p_embs.append(p_embs)
                p_embs = []
        batched_p_embs.append(p_embs)

    print("passage tokenizing done!!!!")
    length = len(val_dataset["context"])

    dot_prod_scores = model.get_score(q_emb, batched_p_embs, eval=True)

    print(dot_prod_scores.size())

    rank = torch.argsort(dot_prod_scores, dim=1, descending=True).squeeze()
    print(dot_prod_scores)
    print(rank)
    print(rank.size())
    torch.save(rank, f"/opt/ml/input/code/colbert/rank/rank_epoch{epoch}.pth")

    k = 100
    score = 0

    for idx in range(length):
        print(dot_prod_scores[idx])
        print(rank[idx])
        print()
        for i in range(k):
            if ground_truth[idx] == context[rank[idx][i]]:
                score += 1

    print(f"{score} over {length} contexts found!!")
    print(f"final score is {score / length}")


if __name__ == "__main__":
    main()
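The loop above counts a hit whenever the gold passage appears anywhere in the top k=100. A small sketch, not part of the commit, of how the same rank, ground_truth, and context objects could be reused to report accuracy at several cutoffs (the cutoff values here are arbitrary):

# Minimal sketch, assuming rank, ground_truth, and context from main() are in scope.
def topk_accuracy(rank, ground_truth, context, ks=(1, 5, 20, 100)):
    results = {}
    for k in ks:
        hits = 0
        for idx in range(len(ground_truth)):
            retrieved = [context[i] for i in rank[idx][:k].tolist()]
            if ground_truth[idx] in retrieved:
                hits += 1
        results[k] = hits / len(ground_truth)
    return results

# Example: print(topk_accuracy(rank, ground_truth, context))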
@@ -0,0 +1,81 @@
# Source: https://github.com/boostcampaitech3/level2-mrc-level2-nlp-11

import json

import pandas as pd
import torch
from datasets import Dataset, DatasetDict, Features, Value
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer

from model import ColbertModel
from tokenizer import tokenize_colbert


def run_colbert_retrieval(datasets):
    test_dataset = datasets["validation"].flatten_indices().to_pandas()
    MODEL_NAME = "klue/bert-base"

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model_config = AutoConfig.from_pretrained(MODEL_NAME)
    special_tokens = {"additional_special_tokens": ["[Q]", "[D]"]}
    ret_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    ret_tokenizer.add_special_tokens(special_tokens)
    model = ColbertModel.from_pretrained(MODEL_NAME)
    model.resize_token_embeddings(ret_tokenizer.vocab_size + 2)

    model.to(device)

    model.load_state_dict(torch.load("/opt/ml/input/code/dense_model/colbert_epoch12.pth"))

    print("opening wiki passage...")
    with open("/opt/ml/input/data/wikipedia_documents.json", "r", encoding="utf-8") as f:
        wiki = json.load(f)
    context = list(dict.fromkeys([v["text"] for v in wiki.values()]))
    print("wiki loaded!!!")

    query = list(test_dataset["question"])
    mrc_ids = test_dataset["id"]
    length = len(test_dataset)

    batched_p_embs = []
    with torch.no_grad():
        model.eval()

        q_seqs_val = tokenize_colbert(query, ret_tokenizer, corpus="query").to("cuda")
        q_emb = model.query(**q_seqs_val).to("cpu")
        print(q_emb.size())

        print("Start passage embedding......")
        p_embs = []
        for step, p in enumerate(tqdm(context)):
            p = tokenize_colbert(p, ret_tokenizer, corpus="doc").to("cuda")
            p_emb = model.doc(**p).to("cpu").numpy()
            p_embs.append(p_emb)
            # keep passage embeddings in chunks of 200 to bound memory
            if (step + 1) % 200 == 0:
                batched_p_embs.append(p_embs)
                p_embs = []
        batched_p_embs.append(p_embs)

    dot_prod_scores = model.get_score(q_emb, batched_p_embs, eval=True)
    print(dot_prod_scores.size())

    rank = torch.argsort(dot_prod_scores, dim=1, descending=True).squeeze()
    print(dot_prod_scores)
    print(rank)
    torch.save(rank, "/opt/ml/input/code/inferecne_colbert_rank.pth")
    print(rank.size())

    # concatenate the top-k retrieved passages into one context string per question
    k = 100
    passages = []

    for idx in range(length):
        passage = ""
        for i in range(k):
            passage += context[rank[idx][i]]
            passage += " "
        passages.append(passage)

    df = pd.DataFrame({"question": query, "id": mrc_ids, "context": passages})
    f = Features(
        {
            "context": Value(dtype="string", id=None),
            "id": Value(dtype="string", id=None),
            "question": Value(dtype="string", id=None),
        }
    )

    complete_datasets = DatasetDict({"validation": Dataset.from_pandas(df, features=f)})
    return complete_datasets
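A minimal usage sketch, not part of the commit: run_colbert_retrieval only needs a DatasetDict whose "validation" split has "question" and "id" columns. The dataset path below is an assumption for illustration, not taken from the repository.

# Minimal usage sketch; the path is assumed.
from datasets import load_from_disk

datasets = load_from_disk("/opt/ml/input/data/test_dataset")  # assumed location of the MRC test split
retrieved = run_colbert_retrieval(datasets)
print(retrieved["validation"][0]["question"])
print(retrieved["validation"][0]["context"][:200])  # first 200 chars of the concatenated top-100 passages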
@@ -0,0 +1,86 @@
# Source: https://github.com/boostcampaitech3/level2-mrc-level2-nlp-11

import pandas as pd
import torch.nn as nn
import numpy as np
from tqdm import tqdm, trange
import argparse
import random
import torch
import torch.nn.functional as F
from transformers import (
    AutoModel,
    BertModel,
    BertPreTrainedModel,
    AdamW,
    TrainingArguments,
    get_linear_schedule_with_warmup,
)


class ColbertModel(BertPreTrainedModel):
    def __init__(self, config):
        super(ColbertModel, self).__init__(config)

        # BertModel is used as the shared encoder for queries and passages
        self.similarity_metric = "cosine"
        self.dim = 128
        self.batch = 8
        self.bert = BertModel(config)
        self.init_weights()
        self.linear = nn.Linear(config.hidden_size, self.dim, bias=False)

    def forward(self, p_inputs, q_inputs):
        Q = self.query(**q_inputs)
        D = self.doc(**p_inputs)
        return self.get_score(Q, D)

    def query(self, input_ids, attention_mask, token_type_ids):
        Q = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)[0]
        Q = self.linear(Q)
        return torch.nn.functional.normalize(Q, p=2, dim=2)

    def doc(self, input_ids, attention_mask, token_type_ids):
        D = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)[0]
        D = self.linear(D)
        return torch.nn.functional.normalize(D, p=2, dim=2)

    def get_score(self, Q, D, eval=False):
        if eval:
            if self.similarity_metric == "cosine":
                final_score = torch.tensor([])
                for D_batch in tqdm(D):
                    D_batch = torch.Tensor(D_batch).squeeze()
                    p_sequence_output = D_batch.transpose(
                        1, 2
                    )  # (batch_size, hidden_size, p_sequence_length)
                    q_sequence_output = Q.view(
                        240, 1, -1, self.dim
                    )  # (num_queries, 1, q_sequence_length, hidden_size); the query count 240 is hard-coded
                    dot_prod = torch.matmul(
                        q_sequence_output, p_sequence_output
                    )  # (num_queries, batch_size, q_sequence_length, p_sequence_length)
                    max_dot_prod_score = torch.max(dot_prod, dim=3)[
                        0
                    ]  # (num_queries, batch_size, q_sequence_length)
                    score = torch.sum(max_dot_prod_score, dim=2)  # (num_queries, batch_size)
                    final_score = torch.cat([final_score, score], dim=1)
                    print(final_score.size())
                return final_score

        else:
            if self.similarity_metric == "cosine":
                p_sequence_output = D.transpose(1, 2)  # (batch_size, hidden_size, p_sequence_length)
                q_sequence_output = Q.view(
                    self.batch, 1, -1, self.dim
                )  # (batch_size, 1, q_sequence_length, hidden_size)
                dot_prod = torch.matmul(
                    q_sequence_output, p_sequence_output
                )  # (batch_size, batch_size, q_sequence_length, p_sequence_length)
                max_dot_prod_score = torch.max(dot_prod, dim=3)[
                    0
                ]  # (batch_size, batch_size, q_sequence_length)
                final_score = torch.sum(max_dot_prod_score, dim=2)  # (batch_size, batch_size)

                return final_score
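get_score applies ColBERT-style late interaction: token-level similarities between the normalized query and passage embeddings, a max over passage tokens, then a sum over query tokens. A single-pair sketch, independent of the class above, that makes this reduction explicit:

# Minimal single-pair MaxSim sketch; shapes and dim are illustrative.
import torch

def maxsim_score(q_emb, d_emb):
    # q_emb: (q_len, dim), d_emb: (d_len, dim), both L2-normalized as in query()/doc()
    sim = q_emb @ d_emb.T               # (q_len, d_len) token-level cosine similarities
    return sim.max(dim=1).values.sum()  # max over passage tokens, sum over query tokens

q = torch.nn.functional.normalize(torch.randn(32, 128), dim=1)
d = torch.nn.functional.normalize(torch.randn(180, 128), dim=1)
print(maxsim_score(q, d))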
@@ -0,0 +1,78 @@
# Source: https://github.com/boostcampaitech3/level2-mrc-level2-nlp-11

import pandas as pd
import torch.nn as nn
import numpy as np
from tqdm import tqdm, trange
import argparse
import random
import torch
import torch.nn.functional as F
from transformers import (
    AutoTokenizer,
    AutoModel,
    BertModel,
    BertPreTrainedModel,
    AdamW,
    TrainingArguments,
    get_linear_schedule_with_warmup,
)


def load_data(datadir):
    dataset = pd.read_csv(datadir)
    dataset = pd.DataFrame(
        {"context": dataset["context"], "query": dataset["question"], "title": dataset["title"]}
    )
    return dataset


def load_tokenizer(MODEL_NAME):
    special_tokens = {"additional_special_tokens": ["[Q]", "[D]"]}
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.add_special_tokens(special_tokens)
    return tokenizer


def tokenize_colbert(dataset, tokenizer, corpus):

    # for inference: a list of query strings
    if corpus == "query":
        preprocessed_data = []
        for query in dataset:
            preprocessed_data.append("[Q] " + query)

        tokenized_query = tokenizer(
            preprocessed_data, return_tensors="pt", padding=True, truncation=True, max_length=128
        )
        return tokenized_query

    # for inference: a single passage string
    elif corpus == "doc":
        preprocessed_data = "[D] " + dataset
        tokenized_context = tokenizer(
            preprocessed_data,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
        )

        return tokenized_context

    # for training: a dataframe with "query" and "context" columns
    else:
        preprocessed_query = []
        preprocessed_context = []
        for query, context in zip(dataset["query"], dataset["context"]):
            preprocessed_context.append("[D] " + context)
            preprocessed_query.append("[Q] " + query)
        tokenized_query = tokenizer(
            preprocessed_query, return_tensors="pt", padding=True, truncation=True, max_length=128
        )

        tokenized_context = tokenizer(
            preprocessed_context,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
        )
        return tokenized_context, tokenized_query
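A short usage sketch, not part of the commit, of the three tokenize_colbert modes; the example strings are placeholder Korean text chosen for illustration.

# Minimal usage sketch; example strings are placeholders.
tokenizer = load_tokenizer("klue/bert-base")
print(tokenizer.convert_tokens_to_ids(["[Q]", "[D]"]))  # two ids beyond the base vocabulary

q_batch = tokenize_colbert(["대한민국의 수도는 어디인가?"], tokenizer, corpus="query")
d_batch = tokenize_colbert("서울은 대한민국의 수도이다.", tokenizer, corpus="doc")
print(q_batch["input_ids"].shape, d_batch["input_ids"].shape)

# Training mode expects a dataframe with "query" and "context" columns,
# e.g. the one returned by load_data(...).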