From ce27bb295fa6eebd55cbe359067163b578210b90 Mon Sep 17 00:00:00 2001 From: claudiu_daniel_hromei <38636-claudiu_daniel_hromei@users.noreply.gitlab.aicrowd.com> Date: Mon, 14 Nov 2022 17:38:43 +0100 Subject: [PATCH] added entity retrieval type option --- main.py | 58 +++++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 52 insertions(+), 6 deletions(-) diff --git a/main.py b/main.py index ab24000..63105e1 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,10 @@ +from datetime import datetime + +start_time = datetime.now() +printable_start_time = start_time.strftime("%H:%M:%S") +print(f"START {printable_start_time}") +############################################################## + from trainer import Trainer from predictor import Predictor from utils.enums import Language @@ -25,7 +32,7 @@ parser.add_argument('mode', type=str, help='The modality: train or predict') -def defineTrainArguments(n_fold, use_cuda, epochs, targetType, modelName, modelVariant, batchSize, learning_rate, early_stopping, quick_train, addMap, addLUType, mapType, grounding): +def defineTrainArguments(n_fold, use_cuda, epochs, targetType, modelName, modelVariant, batchSize, learning_rate, early_stopping, quick_train, addMap, addLUType, mapType, grounding, entityRetrievalType, lexicalReferences, thresholdW2V, thresholdLDIST): global parser # Optional argument @@ -82,8 +89,25 @@ def defineTrainArguments(n_fold, use_cuda, epochs, targetType, modelName, modelV # Optional argument parser.add_argument('-gr','--grounding', type=str, - help='type of grounding to perform. Acceptable values are "full", "half", "no". Default "' + str(grounding) + '". Define only in train mode.') + help='type of grounding to perform. Acceptable values are "full", "half", "post", "no". Default "' + str(grounding) + '". Define only in train mode.') + + # Optional argument + parser.add_argument('-ert','--entityRetrievalType', type=str, + help='type of entity retrieval. Acceptable values are "STR" for string match or "LDIST" for Levenshtein Distance or "W2V" for word2vec. Default "' + str(entityRetrievalType) + '". Define only in train mode.') + + # Optional argument + parser.add_argument('-lexr','--lexicalReferences', type=str, + help='type of lexical references of entity usage. Acceptable values are "all" for using all lrs or "random" for using 1 random lr. Default "' + str(lexicalReferences) + '". Define only in train mode.') + + # Optional argument + parser.add_argument('-tw2v','--thresholdW2V', type=float, + help='threshold for W2V retrieval. Acceptable values are floats between 0 and 1. Default "' + str(thresholdW2V) + '". Define only in train mode.') + # Optional argument + parser.add_argument('-tLDIST','--thresholdLDIST', type=float, + help='threshold for Levenshtein Distance retrieval. Acceptable values are floats between 0 and 1. Default "' + str(thresholdLDIST) + '". Define only in train mode.') + + def definePredictArguments(model, text): global parser @@ -125,7 +149,7 @@ def main(): defineGlobalArguments(task, num_beams, return_sequences, language) - #train mode + # train mode # huricParsingDir = 'data/huric/en' # datasetFile = 'data/data-huric.csv' n_fold = 2 @@ -142,8 +166,12 @@ def main(): addLUType = False mapType = "nomap" grounding = "no" + entityRetrievalType = "STR" + lexicalReferences = "all" + thresholdW2V = 0.5 + thresholdLDIST = 0.8 - defineTrainArguments(n_fold, use_cuda, epochs, target_type, modelName, modelVariant, batch_size, learning_rate, early_stopping, quick_train, addMap, addLUType, mapType, grounding) + defineTrainArguments(n_fold, use_cuda, epochs, target_type, modelName, modelVariant, batch_size, learning_rate, early_stopping, quick_train, addMap, addLUType, mapType, grounding, entityRetrievalType, lexicalReferences, thresholdW2V, thresholdLDIST) #predict mode model = 'outputs' @@ -200,10 +228,19 @@ def main(): addLUType = args.addLUType if args.learning_rate != None: learning_rate = args.learning_rate + if args.entityRetrievalType != None: + entityRetrievalType = args.entityRetrievalType + if args.lexicalReferences != None: + lexicalReferences = args.lexicalReferences + if args.thresholdW2V != None: + thresholdW2V = args.thresholdW2V + if args.thresholdLDIST != None: + thresholdLDIST = args.thresholdLDIST + trainer = Trainer(language, model=modelName, model_variant=modelVariant, task=task, learning_rate=learning_rate, batch_size=batch_size, use_cuda=use_cuda, num_train_epochs=epochs, target_type=target_type, early_stopping=early_stopping, num_beans=num_beams, return_sequences=return_sequences) print("Training and saving models for all folds!") - trainer.train_saving_all_folds_models(n_fold, quick_train=quick_train, addMap=addMap, map_type=mapType, addLUType=addLUType, grounding=grounding) + trainer.train_saving_all_folds_models(n_fold, quick_train=quick_train, addMap=addMap, map_type=mapType, addLUType=addLUType, grounding=grounding, entityRetrievalType=entityRetrievalType, lexicalReferences=lexicalReferences, thresholdW2V=thresholdW2V, thresholdLDIST=thresholdLDIST) print("TRAIN FINISHED") # predict only options @@ -224,4 +261,13 @@ def main(): print("EXAMPLE: `python main.py train`") if __name__ == "__main__": - main() \ No newline at end of file + main() + + +############################################################## +end_time = datetime.now() +difference_time = end_time - start_time +printable_difference_time = difference_time +printable_end_time = end_time.strftime("%H:%M:%S") +print(f"PASSED {printable_difference_time}") +print(f"ENDED AT {printable_end_time}") \ No newline at end of file