Skip to content

Commit

Permalink
fixes for grut2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
claudiu_daniel_hromei committed Nov 22, 2022
1 parent e9112f3 commit f769ae7
Show file tree
Hide file tree
Showing 8 changed files with 354 additions and 106 deletions.
69 changes: 42 additions & 27 deletions main.py → grut.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from utils.enums import Language
import argparse
import logging
from huricParser import HuricParser

###############################################################
# IMPORTANT IMPORTS and OPTIONS
Expand All @@ -32,7 +33,7 @@
parser.add_argument('mode', type=str,
help='The modality: train or predict')

def defineTrainArguments(n_fold, use_cuda, epochs, targetType, modelName, modelVariant, batchSize, learning_rate, early_stopping, quick_train, addMap, addLUType, mapType, grounding, entityRetrievalType, lexicalReferences, thresholdW2V, thresholdLDIST):
def defineTrainArguments(n_fold, use_cuda, epochs, targetType, modelName, modelVariant, batchSize, learning_rate, early_stopping, quick_train, addMap, addLUType, mapType, grounding, lexicalReferences, thresholdW2V, thresholdLDIST):
global parser

# Optional argument
Expand Down Expand Up @@ -91,37 +92,37 @@ def defineTrainArguments(n_fold, use_cuda, epochs, targetType, modelName, modelV
parser.add_argument('-gr','--grounding', type=str,
help='type of grounding to perform. Acceptable values are "full", "half", "post", "no". Default "' + str(grounding) + '". Define only in train mode.')

# Optional argument
parser.add_argument('-ert','--entityRetrievalType', type=str,
help='type of entity retrieval. Acceptable values are "STR" for string match or "LDIST" for Levenshtein Distance or "W2V" for word2vec. Default "' + str(entityRetrievalType) + '". Define only in train mode.')

# Optional argument
parser.add_argument('-lexr','--lexicalReferences', type=str,
help='type of lexical references of entity usage. Acceptable values are "all" for using all lrs or "random" for using 1 random lr. Default "' + str(lexicalReferences) + '". Define only in train mode.')

# Optional argument
parser.add_argument('-tw2v','--thresholdW2V', type=float,
help='threshold for W2V retrieval. Acceptable values are floats between 0 and 1. Default "' + str(thresholdW2V) + '". Define only in train mode.')

# Optional argument
parser.add_argument('-tLDIST','--thresholdLDIST', type=float,
help='threshold for Levenshtein Distance retrieval. Acceptable values are floats between 0 and 1. Default "' + str(thresholdLDIST) + '". Define only in train mode.')



def definePredictArguments(model, text):
def definePredictArguments(model_dir, text, huric_file_path):
global parser

# Optional argument
parser.add_argument('-m', '--model', type=str,
help='path to model. Default "' + model + '". Define only in predict mode.')
parser.add_argument('-m', '--model_dir', type=str,
help='path to model_dir. Default "' + model_dir + '". Define only in predict mode.')

# Optional argument
parser.add_argument('-i','--input', type=str,
help='input text. Default "' + text + '". Define only in predict mode.')

# Optional argument
parser.add_argument('-hrc','--huric_file_path', type=str,
help='path to huric file. Default "' + str(huric_file_path) + '". Define only in predict mode.')


def defineGlobalArguments(task, num_beams, return_sequences, language):
def defineGlobalArguments(task, num_beams, return_sequences, language, entityRetrievalType):
global parser

# Optional argument
Expand All @@ -139,19 +140,24 @@ def defineGlobalArguments(task, num_beams, return_sequences, language):
parser.add_argument('-lan','--language', type=str,
help='dataset language to use. Default "' + language.value + '". Define both in train and predict mode.')

# Optional argument
parser.add_argument('-ert','--entityRetrievalType', type=str,
help='type of entity retrieval. Acceptable values are "STR" for string match or "LDIST" for Levenshtein Distance or "W2V" for word2vec. Default "' + str(entityRetrievalType) + '". Define only in train mode.')



def main():

# both modes
num_beams = None
return_sequences = 1
task = 'SRL'
language = Language.ENGLISH
entityRetrievalType = "STR"

defineGlobalArguments(task, num_beams, return_sequences, language)
defineGlobalArguments(task, num_beams, return_sequences, language, entityRetrievalType)

# train mode
# huricParsingDir = 'data/huric/en'
# datasetFile = 'data/data-huric.csv'
n_fold = 2
use_cuda = False
epochs = 1
Expand All @@ -166,20 +172,21 @@ def main():
addLUType = False
mapType = "nomap"
grounding = "no"
entityRetrievalType = "STR"
lexicalReferences = "all"
thresholdW2V = 0.5
thresholdLDIST = 0.8

defineTrainArguments(n_fold, use_cuda, epochs, target_type, modelName, modelVariant, batch_size, learning_rate, early_stopping, quick_train, addMap, addLUType, mapType, grounding, entityRetrievalType, lexicalReferences, thresholdW2V, thresholdLDIST)
defineTrainArguments(n_fold, use_cuda, epochs, target_type, modelName, modelVariant, batch_size, learning_rate, early_stopping, quick_train, addMap, addLUType, mapType, grounding, lexicalReferences, thresholdW2V, thresholdLDIST)

#predict mode
model = 'outputs'
model_dir = '/model'
text = "take the book near the cat on the sofa"
# huric_file_path = "/data/huric/en/S4R/2748.hrc"
huric_file_path = None

definePredictArguments(model, text)
definePredictArguments(model_dir, text, huric_file_path)

#get arguments from command line
# get arguments from command line
args = parser.parse_args()

# both modes
Expand Down Expand Up @@ -232,10 +239,6 @@ def main():
entityRetrievalType = args.entityRetrievalType
if args.lexicalReferences != None:
lexicalReferences = args.lexicalReferences
if args.thresholdW2V != None:
thresholdW2V = args.thresholdW2V
if args.thresholdLDIST != None:
thresholdLDIST = args.thresholdLDIST


trainer = Trainer(language, model=modelName, model_variant=modelVariant, task=task, learning_rate=learning_rate, batch_size=batch_size, use_cuda=use_cuda, num_train_epochs=epochs, target_type=target_type, early_stopping=early_stopping, num_beans=num_beams, return_sequences=return_sequences)
Expand All @@ -246,12 +249,24 @@ def main():
# predict only options
elif args.mode == 'predict':
print('Starting predict mode...')
if args.model != None:
model = args.model
if args.model_dir != None:
model_dir = args.model_dir
if args.input != None:
text = args.input

predictor = Predictor(model, num_beans=num_beams, return_sequences=return_sequences)
if args.huric_file_path != None:
huric_file_path = args.huric_file_path
if args.entityRetrievalType != None:
entityRetrievalType = args.entityRetrievalType
addMap = True
hp = HuricParser(Language.ENGLISH)
[_, sentence, _], _ = hp.parseHuricFile(huric_file_path, task, "", addMap, noMap=False, map_type="lmd", addLUType=False, grounding="yes", entityRetrievalType=entityRetrievalType)
text += " # " + sentence.split(" # ")[1]
print("File parsed correctly!")
print(f"The sentence with the map is: '{text}'")
else:
print("Huric file path not given, no map added (maybe it's already appended to the input?)")

predictor = Predictor(model_dir=model_dir, num_beans=num_beams, return_sequences=return_sequences)
result = predictor.predict(task, text)
print(result)

Expand Down
53 changes: 27 additions & 26 deletions predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from simpletransformers.t5 import T5Model, T5Args

class Predictor:
def __init__(self, model = 'bart', model_dir = 'outputs/best_model', num_beans = 1, return_sequences = 1, use_cuda = False):
def __init__(self, model_name = 'bart', model_dir = 'outputs/best_model', num_beans = 1, return_sequences = 1, use_cuda = False):


self.encoder_decoder_name = model_dir
Expand All @@ -12,31 +12,32 @@ def __init__(self, model = 'bart', model_dir = 'outputs/best_model', num_beans =
cuda_available = torch.cuda.is_available() and self.use_cuda
print('GPU available: ' + str(cuda_available))

if model == 'mt5':
self.model_args = T5Args()
self.model_args.num_beams = num_beans
self.model_args.num_return_sequences = return_sequences

self.model = T5Model(
model_type="mt5",
model_name=self.encoder_decoder_name,
args=self.model_args,
use_cuda=cuda_available
)
elif model == 'bart':
self.model_args = Seq2SeqArgs()
self.model_args.num_beams = num_beans
self.model_args.num_return_sequences = return_sequences

self.model = Seq2SeqModel(
encoder_decoder_type="bart",
encoder_decoder_name=self.encoder_decoder_name,
args=self.model_args,
use_cuda=cuda_available
)
else:
print("ERROR")
print("Only bart or mt5 models supported for now!")
# if model_name == 'mt5':
# self.model_args = T5Args()
# self.model_args.num_beams = num_beans
# self.model_args.num_return_sequences = return_sequences

# self.model = T5Model(
# model_type="mt5",
# model_name=self.encoder_decoder_name,
# args=self.model_args,
# use_cuda=cuda_available
# )
# elif model_name == 'bart':
self.model_args = Seq2SeqArgs()
self.model_args.num_beams = num_beans
self.model_args.num_return_sequences = return_sequences

self.model = Seq2SeqModel(
encoder_decoder_type="bart",
encoder_decoder_name=self.encoder_decoder_name,
args=self.model_args,
use_cuda=cuda_available
)
# else:
# print("ERROR")
# print("Only bart or mt5 models supported for now!")


def predict(self, task, input):

Expand Down
Loading

0 comments on commit f769ae7

Please sign in to comment.