Skip to content

Commit

Permalink
feat: 엘라스틱서치 수정
Browse files Browse the repository at this point in the history
  • Loading branch information
khs0415p committed Jan 4, 2023
1 parent c321612 commit 8245944
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 35 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ data/*
data.tar.gz
.ipynb_checkpoints
__pycache__
outputs
outputs
*.out
2 changes: 1 addition & 1 deletion arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class ModelArguments:
"""

model_name_or_path: str = field(
default="klue/bert-base",
default="monologg/koelectra-base-v3-discriminator",
metadata={
"help": "Path to pretrained model or model identifier from huggingface.co/models"
},
Expand Down
31 changes: 15 additions & 16 deletions es_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,11 @@ def create_index(self, index_name: str, setting_path: str = "./settings.json"):
self.client.indices.delete(index=index_name)

else:
return
return False

self.client.indices.create(index=index_name, body=settings)

print(f"Create an Index ({index_name})")
return True

def get_indices(self):
indices = list(self.client.indices.get_alias().keys())
Expand All @@ -86,7 +86,7 @@ def delete_index(self, index_name: str):
def insert_data(
self,
index_name: str,
data_path: str = "../data/deduplication_wikipedia_documents.json",
data_path: str = "../data/wikipedia_documents.json",
):
"""_summary_
Expand All @@ -100,14 +100,12 @@ def insert_data(

docs = []
print("Data Loding...")
for k, v in data.items():
for i, v in enumerate(data.values()):
doc = {
"_index": index_name,
"_type": "_doc",
"_id": k,
"document_id": v["document_id"],
"_id": i,
"text": v["text"],
"corpus_source": v["corpus_source"],
"title": v["title"],
}

Expand All @@ -129,6 +127,13 @@ def delete_data(self, index_name: str, doc_id):
self.client.delete(index=index_name, id=doc_id)

print(f"Deleted {doc_id} document.")

def init_index(self, index_name: str):
    """Reset an index: drop it when it already exists, then create it again.

    Args:
        index_name: Name of the Elasticsearch index to (re)initialize.
    """
    already_present = self.client.indices.exists(index=index_name)
    if already_present:
        # Remove the stale index first so the create call below starts clean.
        self.delete_index(index_name=index_name)

    self.create_index(index_name=index_name)

    message = f"Initialization...({index_name})"
    print(message)

def document_count(self, index_name: str):

Expand All @@ -139,15 +144,9 @@ def search(self, index_name: str, question: str, topk: int = 10):

body = {"query": {"bool": {"must": [{"match": {"text": question}}]}}}

responses = self.client.search(index=index_name, body=body, size=topk)["hits"][
"hits"
]
outputs = [
{"text": res["_source"]["text"], "score": res["_score"]}
for res in responses
]
responses = self.client.search(index=index_name, body=body, size=topk)["hits"]["hits"]

return outputs
return responses


if __name__ == "__main__":
Expand All @@ -158,7 +157,7 @@ def search(self, index_name: str, question: str, topk: int = 10):
es.delete_index("wiki_docs")
es.create_index("wiki_docs")
es.insert_data("wiki_docs")
es.document_count("wiki_docs")
print(es.document_count("wiki_docs"))

outputs = es.search("wiki_docs", "소백산맥의 동남부에 위치한 지역은?")

Expand Down
5 changes: 4 additions & 1 deletion inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
load_from_disk,
load_metric,
)
from retrieval import TfidfRetrieval,BM25
from retrieval import TfidfRetrieval,BM25,ElasticRetrieval
from trainer_qa import QuestionAnsweringTrainer
from transformers import (
AutoConfig,
Expand Down Expand Up @@ -113,6 +113,9 @@ def run_sparse_retrieval(
retriever = TfidfRetrieval(
tokenize_fn=tokenize_fn, data_path=data_path, context_path=context_path
)

elif data_args.retrieval_choice=="elastic":
retriever = ElasticRetrieval(host='localhost', port='9200')
retriever.get_sparse_embedding()

if data_args.use_faiss:
Expand Down
114 changes: 100 additions & 14 deletions retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sklearn.feature_extraction.text import TfidfVectorizer
from tqdm.auto import tqdm
from rank_bm25 import BM25Okapi
from es_retrieval import ElasticObject

@contextmanager
def timer(name):
Expand Down Expand Up @@ -519,30 +520,123 @@ def get_relevant_doc_bulk(
return doc_scores, doc_indices



class ElasticRetrieval:
    """Passage retriever that delegates searching to an Elasticsearch index.

    The class wraps an ``ElasticObject`` client and exposes the same
    ``retrieve`` / ``get_sparse_embedding`` / ``get_relevant_doc*`` entry
    points that callers use on the other retrievers in this module.
    """

    def __init__(self, host='localhost', port='9200', index_name='wiki_docs') -> None:
        """Create the Elasticsearch client.

        Args:
            host: Host name passed to ``ElasticObject``.
            port: Port passed to ``ElasticObject``.
            index_name: Index that every search is issued against.
                Defaults to ``'wiki_docs'``, the name that was previously
                hard-coded in each method.
        """
        self.host = host
        self.port = port
        self.index_name = index_name
        self.elastic_client = ElasticObject(host=self.host, port=self.port)

    def retrieve(
        self, query_or_dataset: Union[str, Dataset], topk: Optional[int] = 1
    ) -> Union[Tuple[List, List], pd.DataFrame]:
        """Retrieve passages for a single query string or for a whole dataset.

        Args:
            query_or_dataset: Either one question string, or a ``Dataset``
                that is indexed by ``"question"`` and whose examples provide
                ``"question"`` and ``"id"`` (and optionally ``"context"`` and
                ``"answers"``).
            topk: Maximum number of passages to use per query.

        Returns:
            For a string: ``(scores, ids)`` of the returned hits.
            For a ``Dataset``: a ``pd.DataFrame`` with one row per example
            holding the question, its id, the retrieved hit ids and the
            retrieved texts joined by a single space. When the example has
            both ``"context"`` and ``"answers"``, they are copied to
            ``"original_context"`` and ``"answers"``.
            Any other input type falls through and yields ``None``.
        """
        if isinstance(query_or_dataset, str):
            doc_scores, doc_indices, responses = self.get_relevant_doc(
                query_or_dataset, k=topk
            )
            print("[Search query]\n", query_or_dataset, "\n")

            for i, res in enumerate(responses[:topk]):
                print(f"Top-{i+1} passage with score {doc_scores[i]:4f}")
                print(doc_indices[i])
                print(res['_source']['text'])

            # Slice instead of indexing range(topk): the search may return
            # fewer than ``topk`` hits, which used to raise IndexError here.
            return (doc_scores, doc_indices[:topk])

        elif isinstance(query_or_dataset, Dataset):
            # Return the retrieved passages as a pd.DataFrame.
            total = []
            with timer("query exhaustive search"):
                _, doc_indices, doc_responses = self.get_relevant_doc_bulk(
                    query_or_dataset["question"], k=topk
                )

            for idx, example in enumerate(tqdm(query_or_dataset, desc="Elasticsearch")):
                # Collect the text of (at most) ``topk`` hits for this question.
                retrieved_context = [
                    hit['_source']['text'] for hit in doc_responses[idx][:topk]
                ]

                tmp = {
                    # The query and its id.
                    "question": example["question"],
                    "id": example["id"],
                    # Ids and concatenated text of the retrieved passages.
                    "context_id": doc_indices[idx],
                    "context": " ".join(retrieved_context),
                }
                if "context" in example.keys() and "answers" in example.keys():
                    # Validation data: also keep the ground-truth context and answers.
                    tmp["original_context"] = example["context"]
                    tmp["answers"] = example["answers"]
                total.append(tmp)

            cqas = pd.DataFrame(total)

            return cqas

    def get_sparse_embedding(self):
        """Print the indices currently known to the Elasticsearch client.

        Nothing is built locally; the method exists so this retriever can be
        driven through the same call sequence as the other retrievers.
        """
        with timer("elastic building..."):
            # ``ElasticObject.get_indices`` takes no arguments besides self;
            # passing an index name raised TypeError.
            indices = self.elastic_client.get_indices()
            print('Elastic indices :', indices)

    def get_relevant_doc(
        self, query: str, k: Optional[int] = 1
    ) -> Tuple[List, List, List]:
        """Search the index for one query.

        Args:
            query: Question text.
            k: Number of hits requested from the client.

        Returns:
            ``(scores, ids, responses)`` where ``responses`` is whatever the
            client's ``search`` returned and ``scores`` / ``ids`` are the
            ``'_score'`` / ``'_id'`` fields of each hit, in the same order.
        """
        with timer("query ex search"):
            responses = self.elastic_client.search(
                self.index_name, question=query, topk=k
            )

        doc_score = [res['_score'] for res in responses]
        doc_indices = [res['_id'] for res in responses]

        return doc_score, doc_indices, responses

    def get_relevant_doc_bulk(
        self, queries: List, k: Optional[int] = 10
    ) -> Tuple[List, List, List]:
        """Search the index once per query.

        Args:
            queries: Question texts.
            k: Number of hits requested per query.

        Returns:
            ``(doc_scores, doc_indices, doc_responses)``; each is a list with
            one entry per query, aligned with ``queries``. Entry ``i`` holds
            the hit scores, hit ids and raw search result for ``queries[i]``.
        """
        doc_scores = []
        doc_indices = []
        doc_responses = []

        with timer("query ex search"):
            for query in queries:
                responses = self.elastic_client.search(
                    self.index_name, question=query, topk=k
                )

                # Keep ids in a per-query list. They were previously appended
                # to the outer ``doc_indices``, which broke the alignment
                # between result position and query position.
                doc_scores.append([res['_score'] for res in responses])
                doc_indices.append([res['_id'] for res in responses])
                doc_responses.append(responses)

        return doc_scores, doc_indices, doc_responses


if __name__ == "__main__":

import argparse

parser = argparse.ArgumentParser(description="")
parser.add_argument(
"--dataset_name", metavar="./data/train_dataset", type=str, help=""
"--dataset_name", default="../data/train_dataset", type=str, help=""
)
parser.add_argument(
"--model_name_or_path",
metavar="bert-base-multilingual-cased",
default="bert-base-multilingual-cased",
type=str,
help="",
)
parser.add_argument("--data_path", metavar="./data", type=str, help="")
parser.add_argument("--data_path", default="../data", type=str, help="")
parser.add_argument(
"--context_path", metavar="wikipedia_documents", type=str, help=""
"--context_path", default="../data/wikipedia_documents", type=str, help=""
)
parser.add_argument("--use_faiss", metavar=False, type=bool, help="")
parser.add_argument("--use_faiss", default=False, type=bool, help="")

args = parser.parse_args()

# Test sparse
print(args.dataset_name)
org_dataset = load_from_disk(args.dataset_name)
full_ds = concatenate_datasets(
[
Expand All @@ -553,15 +647,7 @@ def get_relevant_doc_bulk(
print("*" * 40, "query dataset", "*" * 40)
print(full_ds)

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False,)

retriever = SparseRetrieval(
tokenize_fn=tokenizer.tokenize,
data_path=args.data_path,
context_path=args.context_path,
)
retriever = ElasticRetrieval(host='localhost', port='9200')

query = "대통령을 포함한 미국의 행정부 견제권을 갖는 국가 기관은?"

Expand Down
15 changes: 13 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
set_seed,
)
from utils_qa import check_no_error, postprocess_qa_predictions
import wandb

logger = logging.getLogger(__name__)

Expand All @@ -29,11 +30,20 @@ def main():
(ModelArguments, DataTrainingArguments, TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
wandb.init(
project="Hyunsoo",
name="ELECTRA-Aug-Train",
entity="nlp-08-mrc",
config=training_args,
)

print(model_args.model_name_or_path)
print(training_args)

# [참고] argument를 manual하게 수정하고 싶은 경우에 아래와 같은 방식을 사용할 수 있습니다
# training_args.per_device_train_batch_size = 4
# print(training_args.per_device_train_batch_size)
training_args.save_total_limit = 2
training_args.report_to = ["wandb"]
training_args.per_device_train_batch_size = 32

print(f"model is from {model_args.model_name_or_path}")
print(f"data is from {data_args.dataset_name}")
Expand Down Expand Up @@ -79,6 +89,7 @@ def main():
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
)
wandb.watch(model)

print(

Expand Down

0 comments on commit 8245944

Please sign in to comment.