Skip to content

Commit ebefeaf

Browse files
committed
adding example 4
1 parent 3c07346 commit ebefeaf

File tree

9 files changed

+242
-14
lines changed

9 files changed

+242
-14
lines changed

docs/source/data_transformations.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ Sample transform functions
1515
^^^^^^^^^^^^^^^^^^^^^^^^^^
1616
.. automodule:: utils.tranform_functions
1717
:members: snips_intent_ner_to_tsv, snli_entailment_to_tsv,create_fragment_detection_tsv,
18-
msmarco_answerability_detection_to_tsv, bio_ner_to_tsv, msmarco_query_type_to_tsv, qqp_query_similarity_to_tsv
18+
msmarco_answerability_detection_to_tsv, msmarco_query_type_to_tsv, bio_ner_to_tsv, qqp_query_similarity_to_tsv
1919

2020
Your own transform function
2121
^^^^^^^^^^^^^^^^^^^^^^^^^^^

docs/source/shared_encoder.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,5 +79,5 @@ For evaluating the performance on dev and test sets during training, we provide
7979

8080
.. automodule:: utils.eval_metrics
8181
:members: classification_accuracy, classification_f1_score, seqeval_f1_score,
82-
seqeval_precision, seqeval_recall, snips_f1_score, snips_precision, snips_recall
82+
seqeval_precision, seqeval_recall, snips_f1_score, snips_precision, snips_recall, classification_recall
8383

examples/entailment_detection/entailment_snli.ipynb

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@
110110
"!python ../../data_preparation.py \\\n",
111111
" --task_file 'tasks_file_snli.yml' \\\n",
112112
" --data_dir '../../data' \\\n",
113-
" --max_seq_len 384"
113+
" --max_seq_len 128"
114114
]
115115
},
116116
{
@@ -135,13 +135,13 @@
135135
" --task_file 'tasks_file_snli.yml' \\\n",
136136
" --out_dir 'snli_entailment_bert_base' \\\n",
137137
" --epochs 3 \\\n",
138-
" --train_batch_size 8 \\\n",
139-
" --eval_batch_size 16 \\\n",
140-
" --grad_accumulation_steps 2 \\\n",
141-
" --log_per_updates 50 \\\n",
138+
" --train_batch_size 64 \\\n",
139+
" --eval_batch_size 64 \\\n",
140+
" --grad_accumulation_steps 1 \\\n",
141+
" --log_per_updates 100 \\\n",
142142
" --eval_while_train True \\\n",
143143
" --test_while_train True \\\n",
144-
" --max_seq_len 384 \\\n",
144+
" --max_seq_len 128 \\\n",
145145
" --silent True "
146146
]
147147
},
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"!wget https://msmarco.blob.core.windows.net/msmarco/train_v2.1.json.gz -P msmarco_qna_data\n",
10+
"!wget https://msmarco.blob.core.windows.net/msmarco/dev_v2.1.json.gz -P msmarco_qna_data\n",
11+
"!wget https://msmarco.blob.core.windows.net/msmarco/eval_v2.1_public.json.gz -P msmarco_qna_data"
12+
]
13+
},
14+
{
15+
"cell_type": "code",
16+
"execution_count": null,
17+
"metadata": {},
18+
"outputs": [],
19+
"source": [
20+
"!gunzip msmarco_qna_data/train_v2.1.json.gz\n",
21+
"!gunzip msmarco_qna_data/dev_v2.1.json.gz\n",
22+
"!gunzip msmarco_qna_data/eval_v2.1_public.json.gz"
23+
]
24+
},
25+
{
26+
"cell_type": "code",
27+
"execution_count": null,
28+
"metadata": {},
29+
"outputs": [],
30+
"source": [
31+
"!python ../../data_transformations.py \\\n",
32+
" --transform_file 'transform_file_querytype.yml'"
33+
]
34+
},
35+
{
36+
"cell_type": "code",
37+
"execution_count": null,
38+
"metadata": {},
39+
"outputs": [],
40+
"source": [
41+
"!python ../../data_preparation.py \\\n",
42+
" --task_file 'tasks_file_querytype.yml' \\\n",
43+
" --data_dir '../../data' \\\n",
44+
" --max_seq_len 60"
45+
]
46+
},
47+
{
48+
"cell_type": "code",
49+
"execution_count": null,
50+
"metadata": {},
51+
"outputs": [],
52+
"source": [
53+
"!python ../../train.py \\\n",
54+
" --data_dir '../../data/bert-base-uncased_prepared_data' \\\n",
55+
" --task_file 'tasks_file_querytype.yml' \\\n",
56+
" --out_dir 'msmarco_querytype_bert_base' \\\n",
57+
" --epochs 3 \\\n",
58+
" --train_batch_size 64 \\\n",
59+
" --eval_batch_size 64 \\\n",
60+
" --grad_accumulation_steps 1 \\\n",
61+
" --log_per_updates 100 \\\n",
62+
" --eval_while_train True \\\n",
63+
" --test_while_train True \\\n",
64+
" --max_seq_len 60 \\\n",
65+
" --silent True"
66+
]
67+
},
68+
{
69+
"cell_type": "code",
70+
"execution_count": null,
71+
"metadata": {},
72+
"outputs": [],
73+
"source": []
74+
}
75+
],
76+
"metadata": {
77+
"kernelspec": {
78+
"display_name": "Python 3",
79+
"language": "python",
80+
"name": "python3"
81+
},
82+
"language_info": {
83+
"codemirror_mode": {
84+
"name": "ipython",
85+
"version": 3
86+
},
87+
"file_extension": ".py",
88+
"mimetype": "text/x-python",
89+
"name": "python",
90+
"nbconvert_exporter": "python",
91+
"pygments_lexer": "ipython3",
92+
"version": "3.7.3"
93+
}
94+
},
95+
"nbformat": 4,
96+
"nbformat_minor": 4
97+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
querytype:
2+
model_type: BERT
3+
config_name: bert-base-uncased
4+
dropout_prob: 0.2
5+
label_map_or_file:
6+
- DESCRIPTION
7+
- ENTITY
8+
- LOCATION
9+
- NUMERIC
10+
- PERSON
11+
metrics:
12+
- classification_accuracy
13+
loss_type: CrossEntropyLoss
14+
task_type: SingleSenClassification
15+
file_names:
16+
- querytype_train_v2.1.tsv
17+
- querytype_dev_v2.1.tsv
18+
- querytype_eval_v2.1_public.tsv
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
transform1:
2+
transform_func: msmarco_query_type_to_tsv
3+
transform_params:
4+
data_frac : 0.2
5+
read_file_names:
6+
- train_v2.1.json
7+
- dev_v2.1.json
8+
- eval_v2.1_public.json
9+
10+
read_dir: msmarco_qna_data
11+
save_dir: ../../data

utils/data_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
"seqeval_recall" : seqeval_recall,
2626
"snips_f1_score" : snips_f1_score,
2727
"snips_precision" : snips_precision,
28-
"snips_recall" : snips_recall
28+
"snips_recall" : snips_recall,
29+
"classification_recall" : classification_recall
2930
}
3031

3132
TRANSFORM_FUNCS = {
@@ -37,7 +38,8 @@
3738
"msmarco_query_type_to_tsv" : msmarco_query_type_to_tsv,
3839
"imdb_sentiment_data_to_tsv" : imdb_sentiment_data_to_tsv,
3940
"qqp_query_similarity_to_tsv" : qqp_query_similarity_to_tsv,
40-
"msmarco_answerability_detection_to_tsv" : msmarco_answerability_detection_to_tsv
41+
"msmarco_answerability_detection_to_tsv" : msmarco_answerability_detection_to_tsv,
42+
"clinc_out_of_scope_to_tsv" : clinc_out_of_scope_to_tsv
4143
}
4244

4345
class ModelType(IntEnum):

utils/eval_metrics.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1-
1+
"""
2+
File for creating metric functions
3+
"""
24
from sklearn.metrics import accuracy_score, f1_score
5+
from sklearn.metrics import recall_score as class_recall_score
36
from seqeval.metrics import f1_score as seq_f1
47
from seqeval.metrics import precision_score, recall_score
58

@@ -31,6 +34,19 @@ def classification_f1_score(yTrue, yPred):
3134
"""
3235
return f1_score(yTrue, yPred, average='micro')
3336

37+
def classification_recall(yTrue, yPred):
    """
    Standard recall score from sklearn for classification tasks.
    It takes a batch of predictions and labels.

    To use this metric, add **classification_recall** into list of ``metrics`` in task file.

    Args:
        yPred (:obj:`list`) : [0, 2, 1, 3]
        yTrue (:obj:`list`) : [0, 1, 2, 3]

    Returns:
        Recall computed by sklearn's ``recall_score`` with ``average='micro'``.

    """
    # 'micro' matches the averaging used by classification_f1_score in this module
    return class_recall_score(yTrue, yPred, average='micro')
3450

3551
def seqeval_f1_score(yTrue, yPred):
3652
"""

utils/tranform_functions.py

Lines changed: 87 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import os
44
import re
55
import json
6+
import random
67
import pandas as pd
78
from tqdm import tqdm
9+
from collections import defaultdict
810
from statistics import median
911
from sklearn.model_selection import train_test_split
1012
SEED = 42
@@ -440,12 +442,12 @@ def msmarco_query_type_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTrain
440442
#saving
441443
print('number of samples in final data : ', len(dfKeep))
442444
print('writing for file {} at {}'.format(readFile, wrtDir))
443-
dfKeep.to_csv(os.path.join(wrtDir, 'querytype_{}.tsv'.format(readFile.split('.')[0])), sep='\t',
445+
dfKeep.to_csv(os.path.join(wrtDir, 'querytype_{}.tsv'.format(readFile.lower().replace('.json', ''))), sep='\t',
444446
index=False, header=False)
445447
if isTrainFile:
446448
allClasses = dfKeep['query_type'].unique()
447449
labelMap = {lab : i for i, lab in enumerate(allClasses)}
448-
labelMapPath = os.path.join(wrtDir, 'querytype_{}_label_map.joblib'.format(readFile.split('.')[0]))
450+
labelMapPath = os.path.join(wrtDir, 'querytype_{}_label_map.joblib'.format(readFile.lower().replace('.json', '')))
449451
joblib.dump(labelMap, labelMapPath)
450452
print('Created label map file at', labelMapPath)
451453

@@ -668,4 +670,86 @@ def msmarco_answerability_detection_to_tsv(dataDir, readFile, wrtDir, transParam
668670
print('Dev file written at: ', os.path.join(wrtDir, 'msmarco_answerability_dev.tsv'))
669671

670672
devDf.to_csv(os.path.join(wrtDir, 'msmarco_answerability_test.tsv'), sep='\t', index=False, header=False)
671-
print('Test file written at: ', os.path.join(wrtDir, 'msmarco_answerability_test.tsv'))
673+
print('Test file written at: ', os.path.join(wrtDir, 'msmarco_answerability_test.tsv'))
674+
675+
def _write_clinc_tsv(filePath, samples, labels):
    """Write one ``uid<TAB>label<TAB>sentence`` row per sample, uid being the running row index."""
    # the context manager closes the handle even if a write fails midway
    with open(filePath, 'w') as tsvFile:
        for uid, (samp, lab) in enumerate(zip(samples, labels)):
            tsvFile.write("{}\t{}\t{}\n".format(uid, lab, samp))


def clinc_out_of_scope_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTrainFile=False):

    """
    Transforms a CLINC out-of-scope JSON file into binary in-scope / out-of-scope tsv files.
    In-scope sentences get label 1 and out-of-scope sentences get label 0. Three files are
    written into ``wrtDir``: ``clinc_outofscope_train.tsv``, ``clinc_outofscope_dev.tsv`` and
    ``clinc_outofscope_test.tsv``, each row being ``uid \\t label \\t sentence``.

    The JSON file is read with the keys ``train``, ``oos_train``, ``val``, ``oos_val``,
    ``test`` and ``oos_test``; each entry is indexed as a (sentence, intent) pair.

    For using this transform function, set ``transform_func`` : **clinc_out_of_scope_to_tsv** in transform file.

    Args:
        dataDir (:obj:`str`) : Path to the directory where the raw data files to be read are present.
        readFile (:obj:`str`) : This is the file which is currently being read and transformed by the function.
        wrtDir (:obj:`str`) : Path to the directory where to save the transformed tsv files.
        transParamDict (:obj:`dict`, defaults to :obj:`None`): Dictionary requiring the following parameters as key-value

            - ``samples_per_intent_train`` (defaults to 7) : Number of in-scope samples per intent to consider, as this data has imbalance for inscope and outscope

        isTrainFile (:obj:`bool`, defaults to :obj:`False`) : Not used by this function; all three splits are
            produced from the single JSON file.

    """
    # work on a copy so the caller's dictionary is never mutated; also tolerates None
    params = dict(transParamDict or {})
    samplesPerIntent = int(params.get("samples_per_intent_train", 7))

    print("Making data from file {} ...".format(readFile))
    with open(os.path.join(dataDir, readFile)) as rawFile:
        raw = json.load(rawFile)

    print('Num of train samples in-scope: ', len(raw['train']))
    # group the in-scope training sentences by their intent
    inScopeTrain = defaultdict(list)
    for sentence, intent in raw['train']:
        inScopeTrain[intent].append(sentence)

    #sampling
    inscopeSampledTrain = []
    random.seed(SEED)
    for intentSentences in inScopeTrain.values():
        # random.sample raises ValueError when asked for more items than exist,
        # so an intent with too few sentences contributes all of them instead
        numToSample = min(samplesPerIntent, len(intentSentences))
        inscopeSampledTrain += random.sample(intentSentences, numToSample)

    print('Num of sampled train samples in-scope: ', len(inscopeSampledTrain))
    #out of scope train
    outscopeTrain = [sample[0] for sample in raw['oos_train']]
    print('Num of train out-scope samples: ', len(outscopeTrain))

    #train data
    allTrain = inscopeSampledTrain + outscopeTrain
    allTrainLabels = [1]*len(inscopeSampledTrain) + [0]*len(outscopeTrain)

    #writing train data file
    trainPath = os.path.join(wrtDir, 'clinc_outofscope_train.tsv')
    _write_clinc_tsv(trainPath, allTrain, allTrainLabels)
    print('Train file written at: ', trainPath)

    #making dev set
    inscopeDev = [sample[0] for sample in raw['val']]
    outscopeDev = [sample[0] for sample in raw['oos_val']]
    print('Num of val out-scope samples: ', len(outscopeDev))
    print('Num of val in-scope samples: ', len(inscopeDev))

    # NOTE(review): the dev file holds only out-of-scope samples; the in-scope
    # validation sentences are counted above but not written. Kept as in the
    # original implementation - confirm this is intended.
    allDev = outscopeDev
    allDevLabels = [0]*len(outscopeDev)

    #writing dev data file
    devPath = os.path.join(wrtDir, 'clinc_outofscope_dev.tsv')
    _write_clinc_tsv(devPath, allDev, allDevLabels)
    print('Dev file written at: ', devPath)

    #making test set
    inscopeTest = [sample[0] for sample in raw['test']]
    outscopeTest = [sample[0] for sample in raw['oos_test']]
    print('Num of test out-scope samples: ', len(outscopeTest))
    print('Num of test in-scope samples: ', len(inscopeTest))

    allTest = inscopeTest + outscopeTest
    allTestLabels = [1]*len(inscopeTest) + [0]*len(outscopeTest)

    #writing test data file
    testPath = os.path.join(wrtDir, 'clinc_outofscope_test.tsv')
    _write_clinc_tsv(testPath, allTest, allTestLabels)
    print('Test file written at: ', testPath)

0 commit comments

Comments
 (0)