From 82459440168b1fb8833a9d65c81e0b3c3cc98461 Mon Sep 17 00:00:00 2001 From: Hyunsoo Date: Wed, 4 Jan 2023 08:18:49 +0000 Subject: [PATCH] =?UTF-8?q?feat:=20=EC=97=98=EB=9D=BC=EC=8A=A4=ED=8B=B1?= =?UTF-8?q?=EC=84=9C=EC=B9=98=20=EC=88=98=EC=A0=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 +- arguments.py | 2 +- es_retrieval.py | 31 +++++++------ inference.py | 5 ++- retrieval.py | 114 ++++++++++++++++++++++++++++++++++++++++++------ train.py | 15 ++++++- 6 files changed, 135 insertions(+), 35 deletions(-) diff --git a/.gitignore b/.gitignore index 2d30738..00a5db5 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ data/* data.tar.gz .ipynb_checkpoints __pycache__ -outputs \ No newline at end of file +outputs +*.out \ No newline at end of file diff --git a/arguments.py b/arguments.py index 6cf939e..4713b0a 100644 --- a/arguments.py +++ b/arguments.py @@ -9,7 +9,7 @@ class ModelArguments: """ model_name_or_path: str = field( - default="klue/bert-base", + default="monologg/koelectra-base-v3-discriminator", metadata={ "help": "Path to pretrained model or model identifier from huggingface.co/models" }, diff --git a/es_retrieval.py b/es_retrieval.py index 7a91d0d..6a5bbe1 100644 --- a/es_retrieval.py +++ b/es_retrieval.py @@ -60,11 +60,11 @@ def create_index(self, index_name: str, setting_path: str = "./settings.json"): self.client.indices.delete(index=index_name) else: - return + return False self.client.indices.create(index=index_name, body=settings) - print(f"Create an Index ({index_name})") + return True def get_indices(self): indices = list(self.client.indices.get_alias().keys()) @@ -86,7 +86,7 @@ def delete_index(self, index_name: str): def insert_data( self, index_name: str, - data_path: str = "../data/deduplication_wikipedia_documents.json", + data_path: str = "../data/wikipedia_documents.json", ): """_summary_ @@ -100,14 +100,12 @@ def insert_data( docs = [] print("Data Loding...") - for k, v in data.items(): + for i, v in enumerate(data.values()): doc = { "_index": index_name, "_type": "_doc", - "_id": k, - "document_id": v["document_id"], + "_id": i, "text": v["text"], - "corpus_source": v["corpus_source"], "title": v["title"], } @@ -129,6 +127,13 @@ def delete_data(self, index_name: str, doc_id): self.client.delete(index=index_name, id=doc_id) print(f"Deleted {doc_id} document.") + + def init_index(self, index_name: str): + if self.client.indices.exists(index=index_name): + self.delete_index(index_name=index_name) + + self.create_index(index_name=index_name) + print(f"Initialization...({index_name})") def document_count(self, index_name: str): @@ -139,15 +144,9 @@ def search(self, index_name: str, question: str, topk: int = 10): body = {"query": {"bool": {"must": [{"match": {"text": question}}]}}} - responses = self.client.search(index=index_name, body=body, size=topk)["hits"][ - "hits" - ] - outputs = [ - {"text": res["_source"]["text"], "score": res["_score"]} - for res in responses - ] + responses = self.client.search(index=index_name, body=body, size=topk)["hits"]["hits"] - return outputs + return responses if __name__ == "__main__": @@ -158,7 +157,7 @@ def search(self, index_name: str, question: str, topk: int = 10): es.delete_index("wiki_docs") es.create_index("wiki_docs") es.insert_data("wiki_docs") - es.document_count("wiki_docs") + print(es.document_count("wiki_docs")) outputs = es.search("wiki_docs", "소백산맥의 동남부에 위치한 지역은?") diff --git a/inference.py b/inference.py index e1d7444..08bdf72 100644 --- a/inference.py +++ b/inference.py @@ -20,7 +20,7 @@ load_from_disk, load_metric, ) -from retrieval import TfidfRetrieval,BM25 +from retrieval import TfidfRetrieval,BM25,ElasticRetrieval from trainer_qa import QuestionAnsweringTrainer from transformers import ( AutoConfig, @@ -113,6 +113,9 @@ def run_sparse_retrieval( retriever = TfidfRetrieval( tokenize_fn=tokenize_fn, data_path=data_path, context_path=context_path ) + + elif data_args.retrieval_choice=="elastic": + retriever = ElasticRetrieval(host='localhost', port='9200') retriever.get_sparse_embedding() if data_args.use_faiss: diff --git a/retrieval.py b/retrieval.py index 5a3e26f..87dd255 100644 --- a/retrieval.py +++ b/retrieval.py @@ -12,6 +12,7 @@ from sklearn.feature_extraction.text import TfidfVectorizer from tqdm.auto import tqdm from rank_bm25 import BM25Okapi +from es_retrieval import ElasticObject @contextmanager def timer(name): @@ -519,30 +520,123 @@ def get_relevant_doc_bulk( return doc_scores, doc_indices + +class ElasticRetrieval: + def __init__(self, host='localhost', port='9200') -> NoReturn: + self.host = host + self.port = port + self.elastic_client = ElasticObject(host=self.host, port=self.port) + + def retrieve( + self, query_or_dataset: Union[str, Dataset], topk: Optional[int] = 1 + ) -> Union[Tuple[List, List], pd.DataFrame]: + + if isinstance(query_or_dataset, str): + doc_scores, doc_indices, responses = self.get_relevant_doc(query_or_dataset, k=topk) + print("[Search query]\n", query_or_dataset, "\n") + + for i in range(min(topk, len(responses))): + print(f"Top-{i+1} passage with score {doc_scores[i]:4f}") + print(doc_indices[i]) + print(responses[i]['_source']['text']) + + return (doc_scores, [doc_indices[i] for i in range(topk)]) + + elif isinstance(query_or_dataset, Dataset): + # Retrieve한 Passage를 pd.DataFrame으로 반환합니다. + total = [] + with timer("query exhaustive search"): + doc_scores, doc_indices, doc_responses = self.get_relevant_doc_bulk( + query_or_dataset["question"], k=topk + ) + + for idx, example in enumerate(tqdm(query_or_dataset, desc="Elasticsearch")): + # retrieved_context 구하는 부분 수정 + retrieved_context = [] + for i in range(min(topk, len(doc_responses[idx]))): + retrieved_context.append(doc_responses[idx][i]['_source']['text']) + + tmp = { + # Query와 해당 id를 반환합니다. + "question": example["question"], + "id": example["id"], + # Retrieve한 Passage의 id, context를 반환합니다. + "context_id": doc_indices[idx], + "context": " ".join(retrieved_context), # 수정 + } + if "context" in example.keys() and "answers" in example.keys(): + # validation 데이터를 사용하면 ground_truth context와 answer도 반환합니다. + tmp["original_context"] = example["context"] + tmp["answers"] = example["answers"] + total.append(tmp) + + cqas = pd.DataFrame(total) + + return cqas + + def get_sparse_embedding(self): + with timer("elastic building..."): + indices = self.elastic_client.get_indices('wiki_docs') + print('Elastic indices :', indices) + + def get_relevant_doc(self, query: str, k: Optional[int] = 1) -> Tuple[List, List]: + with timer("query ex search"): + responses = self.elastic_client.search('wiki_docs', question=query, topk=k) + doc_score = [] + doc_indices = [] + for res in responses: + doc_score.append(res['_score']) + doc_indices.append(res['_id']) + + return doc_score, doc_indices, responses + def get_relevant_doc_bulk(self, queries: List, k: Optional[int] = 10) -> Tuple[List, List]: + with timer("query ex search"): + doc_scores = [] + doc_indices = [] + doc_responses = [] + + for query in queries: + doc_score = [] + doc_index = [] + responses = self.elastic_client.search('wiki_docs', question=query, topk=k) + + + for res in responses: + doc_score.append(res['_score']) + doc_indices.append(res['_id']) + + doc_scores.append(doc_score) + doc_indices.append(doc_index) + doc_responses.append(responses) + + return doc_scores, doc_indices, doc_responses + + if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="") parser.add_argument( - "--dataset_name", metavar="./data/train_dataset", type=str, help="" + "--dataset_name", default="../data/train_dataset", type=str, help="" ) parser.add_argument( "--model_name_or_path", - metavar="bert-base-multilingual-cased", + default="bert-base-multilingual-cased", type=str, help="", ) - parser.add_argument("--data_path", metavar="./data", type=str, help="") + parser.add_argument("--data_path", default="../data", type=str, help="") parser.add_argument( - "--context_path", metavar="wikipedia_documents", type=str, help="" + "--context_path", default="../data/wikipedia_documents", type=str, help="" ) - parser.add_argument("--use_faiss", metavar=False, type=bool, help="") + parser.add_argument("--use_faiss", default=False, type=bool, help="") args = parser.parse_args() # Test sparse + print(args.dataset_name) org_dataset = load_from_disk(args.dataset_name) full_ds = concatenate_datasets( [ @@ -553,15 +647,7 @@ def get_relevant_doc_bulk( print("*" * 40, "query dataset", "*" * 40) print(full_ds) - from transformers import AutoTokenizer - - tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False,) - - retriever = SparseRetrieval( - tokenize_fn=tokenizer.tokenize, - data_path=args.data_path, - context_path=args.context_path, - ) + retriever = ElasticRetrieval(host='localhost', port='9200') query = "대통령을 포함한 미국의 행정부 견제권을 갖는 국가 기관은?" diff --git a/train.py b/train.py index e0e9625..0c790d7 100644 --- a/train.py +++ b/train.py @@ -17,6 +17,7 @@ set_seed, ) from utils_qa import check_no_error, postprocess_qa_predictions +import wandb logger = logging.getLogger(__name__) @@ -29,11 +30,20 @@ def main(): (ModelArguments, DataTrainingArguments, TrainingArguments) ) model_args, data_args, training_args = parser.parse_args_into_dataclasses() + wandb.init( + project="Hyunsoo", + name="ELECTRA-Aug-Train", + entity="nlp-08-mrc", + config=training_args, + ) + print(model_args.model_name_or_path) + print(training_args) # [참고] argument를 manual하게 수정하고 싶은 경우에 아래와 같은 방식을 사용할 수 있습니다 - # training_args.per_device_train_batch_size = 4 - # print(training_args.per_device_train_batch_size) + training_args.save_total_limit = 2 + training_args.report_to = ["wandb"] + training_args.per_device_train_batch_size = 32 print(f"model is from {model_args.model_name_or_path}") print(f"data is from {data_args.dataset_name}") @@ -79,6 +89,7 @@ def main(): from_tf=bool(".ckpt" in model_args.model_name_or_path), config=config, ) + wandb.watch(model) print(