Commit

Merge pull request #21 from boostcampaitech4lv23nlp2/develop_hyunsoo
khs0415p authored Jan 6, 2023
2 parents 94954e2 + 09a3447 commit 7d8bcfd
Showing 6 changed files with 498 additions and 409 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -6,6 +6,7 @@ data.tar.gz
.ipynb_checkpoints
__pycache__
outputs
*.out
*.json
elastic_setting.py
elastic_test.py
2 changes: 1 addition & 1 deletion arguments.py
@@ -9,7 +9,7 @@ class ModelArguments:
"""

model_name_or_path: str = field(
default="klue/bert-base",
default="monologg/koelectra-base-v3-discriminator",
metadata={
"help": "Path to pretrained model or model identifier from huggingface.co/models"
},
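For orientation, here is a minimal sketch of how that default is consumed. The HfArgumentParser wiring and the command-line override shown below are assumptions based on the usual Hugging Face argument-dataclass pattern, not code taken from this diff.

```python
# Hypothetical usage sketch (not part of this PR): parse ModelArguments and
# override the new default from the command line, assuming the repo follows
# the standard HfArgumentParser pattern.
from transformers import HfArgumentParser

from arguments import ModelArguments

parser = HfArgumentParser(ModelArguments)
(model_args,) = parser.parse_args_into_dataclasses()

# With no CLI flag this now resolves to "monologg/koelectra-base-v3-discriminator";
# passing --model_name_or_path klue/bert-base restores the previous default.
print(model_args.model_name_or_path)
```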
42 changes: 19 additions & 23 deletions es_retrieval.py
@@ -60,11 +60,11 @@ def create_index(self, index_name: str, setting_path: str = "./settings.json"):
self.client.indices.delete(index=index_name)

else:
return
return False

self.client.indices.create(index=index_name, body=settings)

print(f"Create an Index ({index_name})")
return True

def get_indices(self):
indices = list(self.client.indices.get_alias().keys())
@@ -86,7 +86,7 @@ def delete_index(self, index_name: str):
def insert_data(
self,
index_name: str,
data_path: str = "../data/deduplication_wikipedia_documents.json",
data_path: str = "../data/wikipedia_documents.json",
):
"""_summary_
@@ -100,14 +100,12 @@ def insert_data(

docs = []
print("Data Loding...")
for k, v in data.items():
for i, v in enumerate(data.values()):
doc = {
"_index": index_name,
"_type": "_doc",
"_id": k,
"document_id": v["document_id"],
"_id": i,
"text": v["text"],
"corpus_source": v["corpus_source"],
"title": v["title"],
}

@@ -129,6 +127,13 @@ def delete_data(self, index_name: str, doc_id):
self.client.delete(index=index_name, id=doc_id)

print(f"Deleted {doc_id} document.")

def init_index(self, index_name: str):
if self.client.indices.exists(index=index_name):
self.delete_index(index_name=index_name)

self.create_index(index_name=index_name)
print(f"Initialization...({index_name})")

def document_count(self, index_name: str):

@@ -139,30 +144,21 @@ def search(self, index_name: str, question: str, topk: int = 10):

body = {"query": {"bool": {"must": [{"match": {"text": question}}]}}}

responses = self.client.search(index=index_name, body=body, size=topk)["hits"][
"hits"
]
outputs = [
{"text": res["_source"]["text"], "score": res["_score"]}
for res in responses
]
responses = self.client.search(index=index_name, body=body, size=topk)["hits"]["hits"]

return outputs
return responses


if __name__ == "__main__":

es = ElasticObject("localhost:9200")
es.create_index("wiki_docs")
es.create_index("wiki_docs")
es.delete_index("wiki_docs")
es.create_index("wiki_docs")
es.insert_data("wiki_docs")
es.document_count("wiki_docs")
# es.create_index("wiki_docs")
# es.insert_data("wiki_docs")
# print(es.document_count("wiki_docs"))

outputs = es.search("wiki_docs", "소백산맥의 동남부에 위치한 지역은?")

for output in outputs:
print("doc:", output["text"])
print("score:", output["score"])
print("doc:", output['_source']["text"])
print("score:", output["_score"])
print()
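A hedged usage sketch of the revised ElasticObject API above: create_index now reports success with a boolean, init_index drops and recreates an index, and search returns the raw Elasticsearch hits so callers read _source and _score themselves. Host, index name, and data path mirror the defaults in this diff; a running Elasticsearch node is assumed.

```python
# Sketch only: assumes Elasticsearch is reachable on localhost:9200 and that
# ../data/wikipedia_documents.json exists, matching the defaults above.
from es_retrieval import ElasticObject

es = ElasticObject("localhost:9200")
es.init_index("wiki_docs")             # delete the index if it exists, then recreate it
es.insert_data("wiki_docs")            # bulk-insert the Wikipedia documents
print(es.document_count("wiki_docs"))

hits = es.search("wiki_docs", "소백산맥의 동남부에 위치한 지역은?", topk=3)
for hit in hits:                       # raw ES hits: read _source / _score directly
    print(hit["_source"]["title"], hit["_score"])
```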
7 changes: 6 additions & 1 deletion inference.py
@@ -21,7 +21,7 @@
load_metric,
)
from transformers import AutoTokenizer, AutoModelForMaskedLM
from retrieval import TfidfRetrieval, BM25
from retrieval import TfidfRetrieval, BM25, ElasticRetrieval
from trainer_qa import QuestionAnsweringTrainer
from colbert.inference import run_colbert_retrieval
from transformers import (
@@ -124,6 +124,9 @@ def run_sparse_retrieval(
retriever = TfidfRetrieval(
tokenize_fn=tokenize_fn, data_path=data_path, context_path=context_path
)

elif data_args.retrieval_choice=="elastic":
retriever = ElasticRetrieval(host='localhost', port='9200')
retriever.get_sparse_embedding()

if data_args.use_faiss:
@@ -143,6 +146,7 @@

# The train data has ground-truth answers, so the dataset consists of id, question, context, and answers.
elif training_args.do_eval:
df = df.drop(columns=["original_context"])
f = Features(
{
"answers": Sequence(
Expand Down Expand Up @@ -172,6 +176,7 @@ def run_mrc(
model,
) -> NoReturn:
print(datasets["validation"])

# Used only for eval or prediction
column_names = datasets["validation"].column_names

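To make the new branch concrete, a minimal sketch of selecting the Elasticsearch retriever at inference time. build_retriever is a hypothetical helper for illustration; only the retrieval_choice == "elastic" check and the ElasticRetrieval constructor arguments come from this diff.

```python
# Hypothetical helper mirroring the new branch in run_sparse_retrieval; everything
# except the "elastic" case and the constructor arguments is illustrative.
from retrieval import ElasticRetrieval


def build_retriever(retrieval_choice: str):
    if retrieval_choice == "elastic":
        return ElasticRetrieval(host="localhost", port="9200")
    raise ValueError(f"unsupported retrieval_choice: {retrieval_choice}")


retriever = build_retriever("elastic")
retriever.get_sparse_embedding()  # for Elasticsearch this just prints the existing indices
```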
123 changes: 105 additions & 18 deletions retrieval.py
@@ -12,7 +12,7 @@
from sklearn.feature_extraction.text import TfidfVectorizer
from tqdm.auto import tqdm
from rank_bm25 import BM25Okapi

from es_retrieval import ElasticObject

@contextmanager
def timer(name):
@@ -501,25 +501,123 @@ def get_relevant_doc_bulk(self, queries: List, k: Optional[int] = 1) -> Tuple[List, List]:
return doc_scores, doc_indices



class ElasticRetrieval:
def __init__(self, host='localhost', port='9200') -> NoReturn:
self.host = host
self.port = port
self.elastic_client = ElasticObject(host=self.host, port=self.port)

def retrieve(
self, query_or_dataset: Union[str, Dataset], topk: Optional[int] = 1
) -> Union[Tuple[List, List], pd.DataFrame]:

if isinstance(query_or_dataset, str):
doc_scores, doc_indices, responses = self.get_relevant_doc(query_or_dataset, k=topk)
print("[Search query]\n", query_or_dataset, "\n")

for i in range(min(topk, len(responses))):
print(f"Top-{i+1} passage with score {doc_scores[i]:4f}")
print(doc_indices[i])
print(responses[i]['_source']['text'])

return (doc_scores, [doc_indices[i] for i in range(topk)])

elif isinstance(query_or_dataset, Dataset):
# Return the retrieved passages as a pd.DataFrame.
total = []
with timer("query exhaustive search"):
doc_scores, doc_indices, doc_responses = self.get_relevant_doc_bulk(
query_or_dataset["question"], k=topk
)

for idx, example in enumerate(tqdm(query_or_dataset, desc="Elasticsearch")):
# Revised the retrieved_context construction
retrieved_context = []
for i in range(min(topk, len(doc_responses[idx]))):
retrieved_context.append(doc_responses[idx][i]['_source']['text'])

tmp = {
# Return the query and its id.
"question": example["question"],
"id": example["id"],
# Return the id and context of the retrieved passages.
# "context_id": doc_indices[idx],
"context": " ".join(retrieved_context), # 수정
}
if "context" in example.keys() and "answers" in example.keys():
# When validation data is used, also return the ground-truth context and answers.
tmp["original_context"] = example["context"]
tmp["answers"] = example["answers"]
total.append(tmp)

cqas = pd.DataFrame(total)

return cqas

def get_sparse_embedding(self):
with timer("elastic building..."):
indices = self.elastic_client.get_indices()
print('Elastic indices :', indices)

def get_relevant_doc(self, query: str, k: Optional[int] = 1) -> Tuple[List, List]:
with timer("query ex search"):
responses = self.elastic_client.search('wiki_docs', question=query, topk=k)
doc_score = []
doc_indices = []
for res in responses:
doc_score.append(res['_score'])
doc_indices.append(res['_id'])

return doc_score, doc_indices, responses

def get_relevant_doc_bulk(self, queries: List, k: Optional[int] = 10) -> Tuple[List, List]:
with timer("query ex search"):
doc_scores = []
doc_indices = []
doc_responses = []

for query in queries:
doc_score = []
doc_index = []
responses = self.elastic_client.search('wiki_docs', question=query, topk=k)


for res in responses:
    doc_score.append(res['_score'])
    doc_index.append(res['_id'])

doc_scores.append(doc_score)
doc_indices.append(doc_index)
doc_responses.append(responses)

return doc_scores, doc_indices, doc_responses


if __name__ == "__main__":

import argparse

parser = argparse.ArgumentParser(description="")
parser.add_argument("--dataset_name", metavar="./data/train_dataset", type=str, help="")
parser.add_argument(
"--dataset_name", default="../data/train_dataset", type=str, help=""
)
parser.add_argument(
"--model_name_or_path",
metavar="bert-base-multilingual-cased",
default="bert-base-multilingual-cased",
type=str,
help="",
)
parser.add_argument("--data_path", metavar="./data", type=str, help="")
parser.add_argument("--context_path", metavar="wikipedia_documents", type=str, help="")
parser.add_argument("--use_faiss", metavar=False, type=bool, help="")
parser.add_argument("--data_path", default="../data", type=str, help="")
parser.add_argument(
"--context_path", default="../data/wikipedia_documents", type=str, help=""
)
parser.add_argument("--use_faiss", default=False, type=bool, help="")

args = parser.parse_args()

# Test sparse
print(args.dataset_name)
org_dataset = load_from_disk(args.dataset_name)
full_ds = concatenate_datasets(
[
@@ -530,18 +628,7 @@ def get_relevant_doc_bulk(self, queries: List, k: Optional[int] = 1) -> Tuple[List, List]:
print("*" * 40, "query dataset", "*" * 40)
print(full_ds)

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
args.model_name_or_path,
use_fast=False,
)

retriever = SparseRetrieval(
tokenize_fn=tokenizer.tokenize,
data_path=args.data_path,
context_path=args.context_path,
)
retriever = ElasticRetrieval(host='localhost', port='9200')

query = "대통령을 포함한 미국의 행정부 견제권을 갖는 국가 기관은?"

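To round out the new class, a hedged example of calling ElasticRetrieval.retrieve directly, based only on the code above. It assumes an Elasticsearch node on localhost:9200 with a populated wiki_docs index (the index name is hard-coded in get_relevant_doc and get_relevant_doc_bulk), and the dataset path follows the argparse default in the __main__ block.

```python
# Usage sketch for ElasticRetrieval, assuming a populated "wiki_docs" index.
from datasets import load_from_disk

from retrieval import ElasticRetrieval

retriever = ElasticRetrieval(host="localhost", port="9200")

# Single query: returns (scores, document ids) for the top-k hits.
scores, doc_ids = retriever.retrieve("대통령을 포함한 미국의 행정부 견제권을 갖는 국가 기관은?", topk=5)

# Hugging Face Dataset: returns a DataFrame with question / id / context columns,
# plus original_context and answers when the split carries ground truth.
dataset = load_from_disk("../data/train_dataset")["validation"]
df = retriever.retrieve(dataset, topk=10)
print(df.head())
```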