
Commit 2a6db66

bool arguments fix

1 parent 889f978 · commit 2a6db66
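
Why the change: argparse with type=bool does not parse booleans. It calls bool() on the raw string, and any non-empty string, including "False", is truthy, so flags such as --eval_while_train False were always read as True. This commit switches every such argument to action='store_true', so each flag defaults to False and becomes True only when present on the command line. A minimal standalone sketch of the pitfall and the fix (illustrative only, not the project's actual parser):

    import argparse

    # Broken pattern (before this commit): bool("False") is True,
    # so the option can never actually be turned off from the CLI.
    broken = argparse.ArgumentParser()
    broken.add_argument('--silent', type=bool, default=True)
    print(broken.parse_args(['--silent', 'False']).silent)  # True, not False

    # Fixed pattern (this commit): presence of the flag means True.
    fixed = argparse.ArgumentParser()
    fixed.add_argument('--silent', default=False, action='store_true')
    print(fixed.parse_args([]).silent)            # False
    print(fixed.parse_args(['--silent']).silent)  # True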

File tree

10 files changed: +45 -82 lines


examples/answerability_detection/answerability_detection_msmarco.ipynb

Lines changed: 4 additions & 4 deletions
@@ -141,11 +141,11 @@
 " --eval_batch_size 16 \\\n",
 " --grad_accumulation_steps 2 \\\n",
 " --log_per_updates 250 \\\n",
-" --save_per_updates 16000 \\\n",
-" --eval_while_train True \\\n",
-" --test_while_train True \\\n",
 " --max_seq_len 324 \\\n",
-" --silent True "
+" --save_per_updates 16000 \\\n",
+" --eval_while_train \\\n",
+" --test_while_train \\\n",
+" --silent"
 ]
 },
 {

examples/entailment_detection/entailment_snli.ipynb

Lines changed: 3 additions & 3 deletions
@@ -139,10 +139,10 @@
 " --eval_batch_size 64 \\\n",
 " --grad_accumulation_steps 1 \\\n",
 " --log_per_updates 100 \\\n",
-" --eval_while_train True \\\n",
-" --test_while_train True \\\n",
 " --max_seq_len 128 \\\n",
-" --silent True "
+" --eval_while_train \\\n",
+" --test_while_train \\\n",
+" --silent"
 ]
 },
 {

examples/intent_ner_fragment/intent_ner_fragment.ipynb

Lines changed: 3 additions & 4 deletions
@@ -168,7 +168,6 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"\n",
 "!python ../../train.py \\\n",
 " --data_dir '../../data/bert-base-uncased_prepared_data' \\\n",
 " --task_file 'tasks_file_snips.yml' \\\n",
@@ -178,10 +177,10 @@
 " --eval_batch_size 32 \\\n",
 " --grad_accumulation_steps 2 \\\n",
 " --log_per_updates 50 \\\n",
-" --eval_while_train True \\\n",
-" --test_while_train True \\\n",
 " --max_seq_len 50 \\\n",
-" --silent True "
+" --eval_while_train \\\n",
+" --test_while_train \\\n",
+" --silent "
 ]
 },
 {

examples/ner_pos_tagging/ner_pos_tagging_conll.ipynb

Lines changed: 3 additions & 3 deletions
@@ -58,10 +58,10 @@
 " --eval_batch_size 32 \\\n",
 " --grad_accumulation_steps 1 \\\n",
 " --log_per_updates 50 \\\n",
-" --eval_while_train True \\\n",
-" --test_while_train True \\\n",
 " --max_seq_len 50 \\\n",
-" --silent True "
+" --eval_while_train \\\n",
+" --test_while_train \\\n",
+" --silent"
 ]
 },
 {

examples/query_correctness/query_correctness.ipynb

Lines changed: 3 additions & 3 deletions
@@ -58,10 +58,10 @@
 " --eval_batch_size 32 \\\n",
 " --grad_accumulation_steps 1 \\\n",
 " --log_per_updates 20 \\\n",
-" --eval_while_train True \\\n",
-" --test_while_train True \\\n",
 " --max_seq_len 50 \\\n",
-" --silent True"
+" --eval_while_train \\\n",
+" --test_while_train \\\n",
+" --silent"
 ]
 },
 {

examples/query_pair_similarity/query_similarity_qqp.ipynb

Lines changed: 2 additions & 2 deletions
@@ -96,9 +96,9 @@
 " --eval_batch_size 16 \\\n",
 " --grad_accumulation_steps 2 \\\n",
 " --log_per_updates 50 \\\n",
-" --eval_while_train True \\\n",
-" --test_while_train True \\\n",
 " --max_seq_len 200 \\\n",
+" --eval_while_train \\\n",
+" --test_while_train \\\n",
 " --silent True "
 ]
 },

examples/query_type_detection/query_type_detection.ipynb

Lines changed: 3 additions & 3 deletions
@@ -59,10 +59,10 @@
 " --eval_batch_size 64 \\\n",
 " --grad_accumulation_steps 1 \\\n",
 " --log_per_updates 100 \\\n",
-" --eval_while_train True \\\n",
-" --test_while_train True \\\n",
 " --max_seq_len 60 \\\n",
-" --silent True"
+" --eval_while_train \\\n",
+" --test_while_train \\\n",
+" --silent"
 ]
 },
 {

models/model.py

Lines changed: 2 additions & 3 deletions
@@ -225,9 +225,8 @@ def update_step(self, batchMetaData, batchData):
         if self.lossClassList[taskId] and (target is not None):
             self.taskLoss = self.lossClassList[taskId](logits, target, attnMasks=modelInputs[2])
             #tensorboard details
-            if self.params['tensorboard']:
-                self.tbTaskId = taskId
-                self.tbTaskLoss = self.taskLoss.item()
+            self.tbTaskId = taskId
+            self.tbTaskLoss = self.taskLoss.item()
         taskLoss = self.taskLoss / self.params['grad_accumulation_steps']
         taskLoss.backward()
         self.accumulatedStep += 1
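
The if self.params['tensorboard'] guard is dropped here because the --tensorboard argument (another broken type=bool option) is removed from train.py below, making TensorBoard logging unconditional. A minimal sketch of the resulting always-on pattern, with a hypothetical logDir and a dummy scalar value:

    import os
    from torch.utils.tensorboard import SummaryWriter

    logDir = 'logs'  # hypothetical directory, for illustration only
    tensorboard = SummaryWriter(log_dir=os.path.join(logDir, 'tb_logs'))

    # Mirrors the now-unguarded calls in train.py:
    tensorboard.add_scalar('train/avg_loss', 0.42, global_step=1)  # dummy loss value
    tensorboard.close()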

train.py

Lines changed: 22 additions & 26 deletions
@@ -29,51 +29,48 @@ def make_arguments(parser):
                         help = 'path to save the model')
     parser.add_argument('--epochs', type = int, required=True,
                         help = 'number of epochs to train')
-    parser.add_argument('--finetune', type = bool, default= False,
-                        help = "If only the shared model is to be loaded with saved pre-trained multi-task model.\
-                        In this case, you can specify your own tasks with task file and use the pre-trained shared model\
-                        to finetune upon.")
-    parser.add_argument('--freeze_shared_model', type = bool, default=False,
+    parser.add_argument('--freeze_shared_model', default=False, action='store_true',
                         help = "True to freeze the loaded pre-trained shared model and only finetune task specific headers")
     parser.add_argument('--train_batch_size', type = int, default=8,
                         help='batch size to use for training')
     parser.add_argument('--eval_batch_size', type = int, default = 32,
                         help = "batch size to use during evaluation")
-    parser.add_argument('--eval_while_train', type = bool, default= True,
-                        help = "if evaluation on dev set is required during training.")
-    parser.add_argument('--test_while_train', type = bool, default = True,
-                        help = "if evaluation on test set is required during training.")
     parser.add_argument('--grad_accumulation_steps', type =int, default = 1,
                         help = "number of steps to accumulate gradients before update")
     parser.add_argument('--num_of_warmup_steps', type=int, default = 0,
                         help = "warm-up value for scheduler")
     parser.add_argument('--grad_clip_value', type = float, default=1.0,
                         help = "gradient clipping value to avoid gradient overflowing" )
-    parser.add_argument('--debug_mode', default = False, type = bool,
-                        help = "record logs for debugging if True")
     parser.add_argument('--log_file', default='multi_task_logs.log', type = str,
                         help = "name of log file to store")
     parser.add_argument('--log_per_updates', default = 10, type = int,
                         help = "number of steps after which to log loss")
-    parser.add_argument('--silent', type = bool, default = True,
-                        help = "Only write logs to file if True")
     parser.add_argument('--seed', default=42, type = int,
                         help = "seed to set for modules")
     parser.add_argument('--max_seq_len', default=128, type =int,
                         help = "max seq length used for model at time of data preparation")
-    parser.add_argument('--tensorboard', default=True, type = bool,
-                        help = "To create tensorboard logs")
     parser.add_argument('--save_per_updates', default = 0, type = int,
                         help = "to keep saving model after this number of updates")
     parser.add_argument('--limit_save', default = 10, type = int,
                         help = "max number recent checkpoints to keep saved")
     parser.add_argument('--load_saved_model', type=str, default=None,
                         help="path to the saved model in case of loading from saved")
-    parser.add_argument('--resume_train', type=bool, default=False,
-                        help="True for resuming training from a saved model")
+    parser.add_argument('--eval_while_train', default = False, action = 'store_true',
+                        help = "if evaluation on dev set is required during training.")
+    parser.add_argument('--test_while_train', default=False, action = 'store_true',
+                        help = "if evaluation on test set is required during training.")
+    parser.add_argument('--resume_train', default=False, action = 'store_true',
+                        help="Set for resuming training from a saved model")
+    parser.add_argument('--finetune', default= False, action = 'store_true',
+                        help = "If only the shared model is to be loaded with saved pre-trained multi-task model.\
+                        In this case, you can specify your own tasks with task file and use the pre-trained shared model\
+                        to finetune upon.")
+    parser.add_argument('--debug_mode', default = False, action = 'store_true',
+                        help = "record logs for debugging if True")
+    parser.add_argument('--silent', default = False, action = 'store_true',
+                        help = "Only write logs to file if True")
     return parser
 
-
 parser = argparse.ArgumentParser()
 parser = make_arguments(parser)
 args = parser.parse_args()
@@ -193,9 +190,8 @@ def main():
     allParams['gpu'] = torch.cuda.is_available()
     logger.info('task parameters:\n {}'.format(taskParams.taskDetails))
 
-    if args.tensorboard:
-        tensorboard = SummaryWriter(log_dir = os.path.join(logDir, 'tb_logs'))
-        logger.info("Tensorboard writing at {}".format(os.path.join(logDir, 'tb_logs')))
+    tensorboard = SummaryWriter(log_dir = os.path.join(logDir, 'tb_logs'))
+    logger.info("Tensorboard writing at {}".format(os.path.join(logDir, 'tb_logs')))
 
     # making handlers for train
     logger.info("Creating data handlers for training...")
@@ -268,11 +264,11 @@ def main():
                                                     taskName,
                                                     avgLoss,
                                                     model.taskLoss.item()))
-                if args.tensorboard:
-                    tensorboard.add_scalar('train/avg_loss', avgLoss, global_step= model.globalStep)
-                    tensorboard.add_scalar('train/{}_loss'.format(taskName),
-                                           model.taskLoss.item(),
-                                           global_step=model.globalStep)
+
+                tensorboard.add_scalar('train/avg_loss', avgLoss, global_step= model.globalStep)
+                tensorboard.add_scalar('train/{}_loss'.format(taskName),
+                                       model.taskLoss.item(),
+                                       global_step=model.globalStep)
 
             if args.save_per_updates > 0 and ( (model.globalStep+1) % args.save_per_updates)==0 and (model.accumulatedStep+1==args.grad_accumulation_steps):
                 savePath = os.path.join(args.out_dir, 'multi_task_model_{}_{}.pt'.format(epoch,
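
A quick way to sanity-check the reworked flags is to feed the parser a minimal argv. The sketch below uses a stripped-down copy of the converted arguments rather than train.py itself (required options such as --epochs are omitted for brevity):

    import argparse

    parser = argparse.ArgumentParser()
    for flag in ('--eval_while_train', '--test_while_train', '--resume_train',
                 '--finetune', '--debug_mode', '--silent', '--freeze_shared_model'):
        parser.add_argument(flag, default=False, action='store_true')

    args = parser.parse_args(['--eval_while_train', '--test_while_train', '--silent'])
    print(args.eval_while_train, args.test_while_train, args.silent)  # True True True
    print(args.resume_train, args.finetune)  # False False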

utils/tranform_functions.py

Lines changed: 0 additions & 31 deletions
@@ -672,37 +672,6 @@ def msmarco_answerability_detection_to_tsv(dataDir, readFile, wrtDir, transParam
     devDf.to_csv(os.path.join(wrtDir, 'msmarco_answerability_test.tsv'), sep='\t', index=False, header=False)
     print('Test file written at: ', os.path.join(wrtDir, 'msmarco_answerability_test.tsv'))
 
-def query_correctness_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTrainFile=False):
-
-    """
-
-    - Query correctness transformed file
-
-    For using this transform function, set ``transform_func`` : **query_correctness_to_tsv** in transform file.
-
-    Args:
-        dataDir (:obj:`str`) : Path to the directory where the raw data files to be read are present..
-        readFile (:obj:`str`) : This is the file which is currently being read and transformed by the function.
-        wrtDir (:obj:`str`) : Path to the directory where to save the transformed tsv files.
-        transParamDict (:obj:`dict`, defaults to :obj:`None`): Dictionary of function specific parameters. Not required for this transformation function.
-
-    """
-    print('Making data from file {}'.format(readFile))
-    df = pd.read_csv(os.path.join(dataDir, readFile), sep='\t', header=None, names = ['query', 'label'])
-
-    # we consider anything above 0.6 as structured query (3 or more annotations as structured), and others as non-structured
-
-    #df['label'] = [str(lab) for lab in df['label']]
-    df['label'] = [int(lab>=0.6)for lab in df['label']]
-
-    data = [ [str(i), str(row['label']), row['query'] ] for i, row in df.iterrows()]
-
-    wrtDf = pd.DataFrame(data, columns = ['uid', 'label', 'query'])
-
-    #writing
-    wrtDf.to_csv(os.path.join(wrtDir, 'query_correctness_{}'.format(readFile)), sep="\t", index=False, header=False)
-    print('File saved at: ', os.path.join(wrtDir, 'query_correctness_{}'.format(readFile)))
-
 def clinc_out_of_scope_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTrainFile=False):
 
     """
