-
Notifications
You must be signed in to change notification settings - Fork 2
/
train_ner.py
120 lines (98 loc) · 5.22 KB
/
train_ner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import logging
import os
import random
from tqdm import trange
import warnings
from argparse import ArgumentParser
import coloredlogs
import torch
import yaml
from datetime import datetime
from datasets import utils
from models.baseline import Baseline
from models.majority_classifier import MajorityClassifier
from models.maml import MAML
from models.nearest_neighbor import NearestNeighborClassifier
from models.proto_network import PrototypicalNetwork
# Silence every UserWarning for the whole process, then attach coloured,
# DEBUG-level formatting to the shared 'MetaLearningLog' logger.
warnings.filterwarnings("ignore", category=UserWarning)
logger = logging.getLogger('MetaLearningLog')
coloredlogs.install(
    level='DEBUG',
    logger=logger,
    fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
def load_config(config_file):
    """Load a YAML experiment configuration and add run-level metadata.

    Args:
        config_file: Path to a YAML file whose top-level node is a mapping.

    Returns:
        dict: The parsed configuration, with two extra keys:
            ``base_path`` -- absolute directory containing this script; the
                caller resolves data and ``saved_models`` paths against it.
            ``stamp`` -- run identifier, fixed to ``"stable"`` (a
                timestamp-based value is left commented out below).

    Raises:
        ValueError: If the file is empty or its top level is not a mapping.
    """
    # Explicit encoding so the result does not depend on the platform locale.
    with open(config_file, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    # safe_load yields None for an empty file and a list/scalar for
    # non-mapping documents; report that clearly instead of failing with a
    # TypeError on the item assignments below.
    if not isinstance(config, dict):
        raise ValueError('Configuration file {} must contain a YAML mapping, got {}'.format(
            config_file, type(config).__name__))
    config['base_path'] = os.path.dirname(os.path.abspath(__file__))
    config['stamp'] = "stable"  # str(datetime.now()).replace(':', '-').replace(' ', '_')
    return config
# Constructor for each supported ``config['meta_learner']`` value. The
# majority baseline takes no configuration, hence the wrapping lambda.
_META_LEARNERS = {
    'maml': MAML,
    'proto_net': PrototypicalNetwork,
    'baseline': Baseline,
    'majority': lambda config: MajorityClassifier(),
    'nearest_neighbor': NearestNeighborClassifier,
}


def _generate_ner_episodes(config, data_path, labels_file, n_episodes, meta_train):
    """Sample NER episodes from one data file using the configured shot/query sizes.

    Returns whatever ``utils.generate_ner_episodes`` returns; the callers
    unpack it as ``(episodes, label_map)``.
    """
    return utils.generate_ner_episodes(dir=data_path,
                                       labels_file=labels_file,
                                       n_episodes=n_episodes,
                                       n_support_examples=config['num_shots']['ner'],
                                       n_query_examples=config['num_test_samples']['ner'],
                                       task='ner',
                                       meta_train=meta_train)


def main():
    """Meta-train the configured learner on NER episodes, then meta-test it.

    Command-line arguments:
        --config     Path to the YAML configuration file (required).
        --multi_gpu  Sets ``config['multi_gpu']`` to True.

    Raises:
        NotImplementedError: If ``config['meta_learner']`` is not a key of
            ``_META_LEARNERS``.
    """
    # Parse arguments
    parser = ArgumentParser()
    parser.add_argument('--config', dest='config_file', type=str, help='Configuration file', required=True)
    parser.add_argument('--multi_gpu', action='store_true')
    args = parser.parse_args()

    # Load configuration
    config = load_config(args.config_file)
    config['multi_gpu'] = args.multi_gpu
    logger.info('Using configuration: {}'.format(config))

    # Validate the learner name up front so a config typo is reported before
    # any episodes are generated. Construction itself still happens after
    # episode generation, as before.
    learner_name = config['meta_learner']
    if learner_name not in _META_LEARNERS:
        raise NotImplementedError('Unknown meta_learner {!r}; expected one of {}'.format(
            learner_name, sorted(_META_LEARNERS)))

    # Set seeds for reproducibility
    torch.manual_seed(42)
    random.seed(42)

    # Episodes for meta-training and meta-validation (meta-test episodes are
    # sampled afresh in each test run below).
    train_episodes, val_episodes = [], []

    # Directory for saving models
    os.makedirs(os.path.join(config['base_path'], 'saved_models'), exist_ok=True)

    # Paths for the NER dataset. Alternative few-shot splits live alongside as
    # dev-g1-traincls-{k}shot.txt / test-g1-testcls-{k}shot.txt, where
    # k = config['num_test_samples']['ner'].
    ner_base_path = os.path.join(config['base_path'], '../data/ontonotes-bert/')
    ner_train_path = os.path.join(ner_base_path, 'train.txt')
    ner_val_path = os.path.join(ner_base_path, 'dev.txt')
    ner_test_path = os.path.join(ner_base_path, 'test.txt')
    labels_train = os.path.join(ner_base_path, 'labels-g1-train.txt')
    # Validation and test episodes both draw from the held-out label set.
    labels_test = os.path.join(ner_base_path, 'labels-g1-test.txt')

    # Generate episodes for NER
    logger.info('Generating episodes for NER')
    ner_train_episodes, _ = _generate_ner_episodes(config, ner_train_path, labels_train,
                                                   config['num_train_episodes']['ner'],
                                                   meta_train=True)
    ner_val_episodes, _ = _generate_ner_episodes(config, ner_val_path, labels_test,
                                                 config['num_val_episodes']['ner'],
                                                 meta_train=False)
    train_episodes.extend(ner_train_episodes)
    val_episodes.extend(ner_val_episodes)
    logger.info('Finished generating episodes for NER')

    # Initialize meta learner
    meta_learner = _META_LEARNERS[learner_name](config)

    # Meta-training
    meta_learner.training(train_episodes, val_episodes)
    logger.info('Meta-learning completed')

    # Meta-testing: repeated over independently sampled test episodes.
    # ``num_test_runs`` is an optional config key (default 5).
    num_test_runs = config.get('num_test_runs', 5)
    for _ in trange(num_test_runs):
        test_episodes, label_map = _generate_ner_episodes(config, ner_test_path, labels_test,
                                                          config['num_test_episodes']['ner'],
                                                          meta_train=False)
        meta_learner.testing(test_episodes, label_map)
    logger.info('Meta-testing completed')


if __name__ == '__main__':
    main()