Skip to content

Commit

Permalink
added entity retrieval type option
Browse files Browse the repository at this point in the history
  • Loading branch information
claudiu_daniel_hromei committed Nov 14, 2022
1 parent 9fe0e75 commit ce27bb2
Showing 1 changed file with 52 additions and 6 deletions.
58 changes: 52 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
from datetime import datetime

start_time = datetime.now()
printable_start_time = start_time.strftime("%H:%M:%S")
print(f"START {printable_start_time}")
##############################################################

from trainer import Trainer
from predictor import Predictor
from utils.enums import Language
Expand Down Expand Up @@ -25,7 +32,7 @@
parser.add_argument('mode', type=str,
help='The modality: train or predict')

def defineTrainArguments(n_fold, use_cuda, epochs, targetType, modelName, modelVariant, batchSize, learning_rate, early_stopping, quick_train, addMap, addLUType, mapType, grounding):
def defineTrainArguments(n_fold, use_cuda, epochs, targetType, modelName, modelVariant, batchSize, learning_rate, early_stopping, quick_train, addMap, addLUType, mapType, grounding, entityRetrievalType, lexicalReferences, thresholdW2V, thresholdLDIST):
global parser

# Optional argument
Expand Down Expand Up @@ -82,8 +89,25 @@ def defineTrainArguments(n_fold, use_cuda, epochs, targetType, modelName, modelV

# Optional argument
parser.add_argument('-gr','--grounding', type=str,
help='type of grounding to perform. Acceptable values are "full", "half", "no". Default "' + str(grounding) + '". Define only in train mode.')
help='type of grounding to perform. Acceptable values are "full", "half", "post", "no". Default "' + str(grounding) + '". Define only in train mode.')

# Optional argument
parser.add_argument('-ert','--entityRetrievalType', type=str,
help='type of entity retrieval. Acceptable values are "STR" for string match or "LDIST" for Levenshtein Distance or "W2V" for word2vec. Default "' + str(entityRetrievalType) + '". Define only in train mode.')

# Optional argument
parser.add_argument('-lexr','--lexicalReferences', type=str,
help='type of lexical references of entity usage. Acceptable values are "all" for using all lrs or "random" for using 1 random lr. Default "' + str(lexicalReferences) + '". Define only in train mode.')

# Optional argument
parser.add_argument('-tw2v','--thresholdW2V', type=float,
help='threshold for W2V retrieval. Acceptable values are floats between 0 and 1. Default "' + str(thresholdW2V) + '". Define only in train mode.')

# Optional argument
parser.add_argument('-tLDIST','--thresholdLDIST', type=float,
help='threshold for Levenshtein Distance retrieval. Acceptable values are floats between 0 and 1. Default "' + str(thresholdLDIST) + '". Define only in train mode.')



def definePredictArguments(model, text):
global parser
Expand Down Expand Up @@ -125,7 +149,7 @@ def main():

defineGlobalArguments(task, num_beams, return_sequences, language)

#train mode
# train mode
# huricParsingDir = 'data/huric/en'
# datasetFile = 'data/data-huric.csv'
n_fold = 2
Expand All @@ -142,8 +166,12 @@ def main():
addLUType = False
mapType = "nomap"
grounding = "no"
entityRetrievalType = "STR"
lexicalReferences = "all"
thresholdW2V = 0.5
thresholdLDIST = 0.8

defineTrainArguments(n_fold, use_cuda, epochs, target_type, modelName, modelVariant, batch_size, learning_rate, early_stopping, quick_train, addMap, addLUType, mapType, grounding)
defineTrainArguments(n_fold, use_cuda, epochs, target_type, modelName, modelVariant, batch_size, learning_rate, early_stopping, quick_train, addMap, addLUType, mapType, grounding, entityRetrievalType, lexicalReferences, thresholdW2V, thresholdLDIST)

#predict mode
model = 'outputs'
Expand Down Expand Up @@ -200,10 +228,19 @@ def main():
addLUType = args.addLUType
if args.learning_rate != None:
learning_rate = args.learning_rate
if args.entityRetrievalType != None:
entityRetrievalType = args.entityRetrievalType
if args.lexicalReferences != None:
lexicalReferences = args.lexicalReferences
if args.thresholdW2V != None:
thresholdW2V = args.thresholdW2V
if args.thresholdLDIST != None:
thresholdLDIST = args.thresholdLDIST


trainer = Trainer(language, model=modelName, model_variant=modelVariant, task=task, learning_rate=learning_rate, batch_size=batch_size, use_cuda=use_cuda, num_train_epochs=epochs, target_type=target_type, early_stopping=early_stopping, num_beans=num_beams, return_sequences=return_sequences)
print("Training and saving models for all folds!")
trainer.train_saving_all_folds_models(n_fold, quick_train=quick_train, addMap=addMap, map_type=mapType, addLUType=addLUType, grounding=grounding)
trainer.train_saving_all_folds_models(n_fold, quick_train=quick_train, addMap=addMap, map_type=mapType, addLUType=addLUType, grounding=grounding, entityRetrievalType=entityRetrievalType, lexicalReferences=lexicalReferences, thresholdW2V=thresholdW2V, thresholdLDIST=thresholdLDIST)
print("TRAIN FINISHED")

# predict only options
Expand All @@ -224,4 +261,13 @@ def main():
print("EXAMPLE: `python main.py train`")

if __name__ == "__main__":
main()
main()


##############################################################
end_time = datetime.now()
difference_time = end_time - start_time
printable_difference_time = difference_time
printable_end_time = end_time.strftime("%H:%M:%S")
print(f"PASSED {printable_difference_time}")
print(f"ENDED AT {printable_end_time}")

0 comments on commit ce27bb2

Please sign in to comment.