
Commit 62533b0

Apply AQA ("AQA 적용")
1 parent 5b13a7f · commit 62533b0


44 files changed: +13722, −0 lines changed

codes/active-qa/px/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
(new empty file: one blank line added)

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
(new empty file: one blank line added; file path hidden in this view)
Lines changed: 257 additions & 0 deletions
@@ -0,0 +1,257 @@
(file path hidden in this view)

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wraps the BiDAF model for use as an environment.

This environment is used for the SQuAD task. It uses a BiDAF model to
produce an answer on a specified SQuAD datapoint for a new question rather
than the original one.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import math
import os

import nltk
import tensorflow as tf

from third_party.bi_att_flow.basic import cli as bidaf_cli
from third_party.bi_att_flow.basic import evaluator as bidaf_eval
from third_party.bi_att_flow.basic import graph_handler as bidaf_graph
from third_party.bi_att_flow.basic import model as bidaf_model
from third_party.bi_att_flow.basic import read_data as bidaf_data


class BidafEnvironment(object):
  """Environment containing the BiDAF model.

  This environment loads a BiDAF model and preprocessed data for a chosen
  SQuAD dataset. The environment is queried with a pointer to an existing
  datapoint, which contains a preprocessed SQuAD document, and a question.
  BiDAF is run using the given question against the document, and the top
  answer with its score is returned.

  Attributes:
    config: BiDAF configuration read from cli.py.
    data: Preprocessed SQuAD dataset.
    evaluator: BiDAF evaluation object.
    graph_handler: BiDAF object used to manage the TF graph.
    sess: Single TensorFlow session used by the environment.
    model: A BiDAF Model object.
  """

  def __init__(self,
               data_dir,
               shared_path,
               model_dir,
               docid_separator='###',
               debug_mode=False,
               load_test=False,
               load_impossible_questions=False):
    """Constructor loads the BiDAF configuration, model and data.

    Args:
      data_dir: Directory containing preprocessed SQuAD data.
      shared_path: Path to shared data generated at training time.
      model_dir: Directory containing parameters of a pre-trained BiDAF model.
      docid_separator: Separator used to split the suffix off the docid string.
      debug_mode: If true, logs additional debug information.
      load_test: Whether the test set should be loaded as well.
      load_impossible_questions: Whether info about the impossibility of
        questions should be loaded.
    """
    self.config = bidaf_cli.get_config()
    self.config.save_dir = model_dir
    self.config.data_dir = data_dir
    self.config.shared_path = shared_path
    self.config.mode = 'forward'  # Inference only; no training.
    self.docid_separator = docid_separator
    self.debug_mode = debug_mode

    self.datasets = ['train', 'dev']
    if load_test:
      self.datasets.append('test')

    data_filter = None
    self.data = dict()
    for dataset in self.datasets:
      self.data[dataset] = bidaf_data.read_data(
          self.config, dataset, True, data_filter=data_filter)
    bidaf_data.update_config(self.config, self.data.values())

    models = bidaf_model.get_model(self.config)
    self.evaluator = bidaf_eval.MultiGPUF1Evaluator(self.config, models)
    self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    self.graph_handler = bidaf_graph.GraphHandler(self.config, models[0])
    self.graph_handler.initialize(self.sess)

    nltk_data_path = os.path.join(os.path.expanduser('~'), 'data')
    nltk.data.path.append(nltk_data_path)

    self.impossible_ids = set()
    if load_impossible_questions:
      tf.logging.info('Loading impossible question ids.')
      for dataset in self.datasets:
        self.impossible_ids.update(self._ReadImpossiblities(dataset, data_dir))
      if self.debug_mode:
        tf.logging.info('Loaded {} impossible question ids.'.format(
            len(self.impossible_ids)))

  def _ReadImpossiblities(self, dataset, data_dir):
    """Collects the docids of all impossible questions."""
    data_path = os.path.join(data_dir, '{}-v2.0.json'.format(dataset))
    impossible_ids = []
    with tf.gfile.Open(data_path, 'r') as fh:
      data = json.load(fh)
      for document in data['data']:
        for paragraph in document['paragraphs']:
          for question in paragraph['qas']:
            if question['is_impossible']:
              impossible_ids.append(question['id'])

    if self.debug_mode:
      tf.logging.info('Loaded {} impossible question ids from {}.'.format(
          len(impossible_ids), dataset))
    return impossible_ids
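
  # For reference, the method above assumes the standard SQuAD v2.0 JSON
  # layout (a sketch, not part of the original commit):
  #
  #   {"data": [{"paragraphs": [{"context": "...",
  #                              "qas": [{"id": "...",
  #                                       "question": "...",
  #                                       "is_impossible": true}]}]}]}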

  def _WordTokenize(self, text):
    """Tokenizes the text with NLTK for consistency with BiDAF."""
    return [
        token.replace("''", '"').replace('``', '"')
        for token in nltk.word_tokenize(text)
    ]
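
  # Example of the quote normalization above (a sketch; NLTK's Penn
  # Treebank-style tokenizer emits `` and '' for double quotes):
  #   nltk.word_tokenize('he said "hi"')  ->  ['he', 'said', '``', 'hi', "''"]
  #   after the replace() calls           ->  ['he', 'said', '"', 'hi', '"']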

  def _PreprocessQaData(self, questions, document_ids):
    """Prepares the BiDAF Data object.

    Loads a batch of SQuAD datapoints, identified by their 'ids' field. The
    questions are replaced with those specified in the input. All datapoints
    must come from the same original dataset (train, dev or test), else the
    shared data will be incorrect; the first id in document_ids is used to
    determine the dataset, and a KeyError is raised if the other ids are not
    in that dataset.

    Args:
      questions: List of strings used to replace the original questions.
      document_ids: Identifiers for the SQuAD datapoints to use.

    Returns:
      data: Dict of raw datapoint fields for the desired datapoints only.
      data.shared: The appropriate shared data from the dataset containing
        the ids in document_ids.
      id2questions_dict: A dict mapping docids to original questions and
        rewrites.

    Raises:
      KeyError: If the document_ids are not all present in a single preloaded
        dataset.
    """
    first_docid = document_ids[0].split(self.docid_separator)[0]
    if first_docid in self.data['train'].data['ids']:
      dataset = self.data['train']
    elif first_docid in self.data['dev'].data['ids']:
      dataset = self.data['dev']
    elif 'test' in self.data and first_docid in self.data['test'].data['ids']:
      dataset = self.data['test']
    else:
      raise KeyError('Document id not present: {}'.format(first_docid))
    data_indices = [
        dataset.data['ids'].index(document_ids[i].split(
            self.docid_separator)[0]) for i in range(len(document_ids))
    ]

    data_out = dict()
    # Copies the relevant datapoints, retaining the input docids.
    for key in dataset.data.keys():
      if key == 'ids':
        data_out[key] = document_ids
      else:
        data_out[key] = [dataset.data[key][i] for i in data_indices]
    if self.debug_mode:
      for q in data_out['q']:
        tf.logging.info('Original question: {}'.format(
            ' '.join(q).encode('utf-8')))

    # Replaces the question in each datapoint with its rewrite.
    id2questions_dict = dict()
    for i in range(len(questions)):
      id2questions_dict[data_out['ids'][i]] = dict()
      id2questions_dict[data_out['ids'][i]]['original'] = ' '.join(
          data_out['q'][i])
      data_out['q'][i] = self._WordTokenize(questions[i])

      if len(data_out['q'][i]) > self.config.max_ques_size:
        tf.logging.info('Truncated question from {} to {}'.format(
            len(data_out['q'][i]), self.config.max_ques_size))
        data_out['q'][i] = data_out['q'][i][:self.config.max_ques_size]

      id2questions_dict[data_out['ids'][i]]['raw_rewrite'] = questions[i]
      id2questions_dict[data_out['ids'][i]]['rewrite'] = ' '.join(
          data_out['q'][i])
      # Rebuilds the character-level question ('cq') to match the new tokens.
      data_out['cq'][i] = [list(qij) for qij in data_out['q'][i]]

    if self.debug_mode:
      for q in data_out['q']:
        tf.logging.info('New question: {}'.format(
            ' '.join(q).encode('utf-8')))

    return data_out, dataset.shared, id2questions_dict
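
  # Shape of the id2questions_dict built above (a sketch; strings are
  # illustrative only):
  #   {docid: {'original':    'what is aqa ?',   # question from the dataset
  #            'raw_rewrite': 'What is AQA?',    # rewrite as passed in
  #            'rewrite':     'What is AQA ?'}}  # rewrite after tokenization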

  def IsImpossible(self, document_id):
    """Returns True if the question for this docid is marked impossible."""
    return document_id in self.impossible_ids

  def GetAnswers(self, questions, document_ids):
    """Computes answers for the given questions on SQuAD datapoints.

    Runs the BiDAF model on the specified SQuAD datapoints, but with the
    input questions in place of the originals.

    Args:
      questions: List of strings used to replace the original questions.
      document_ids: Identifiers for the SQuAD datapoints to use.

    Returns:
      e.id2answer_dict: A dict containing the answers and their scores.
      id2questions_dict: A dict mapping docids to original questions and
        rewrites.
      e.loss: Scalar training loss for the entire batch.

    Raises:
      ValueError: If the number of questions and document_ids differ.
      ValueError: If the document_ids are not unique.
    """
    if len(questions) != len(document_ids):
      raise ValueError('Number of questions and document_ids must be equal.')
    if len(document_ids) > len(set(document_ids)):
      raise ValueError('document_ids must be unique.')
    raw_data, shared, id2questions_dict = self._PreprocessQaData(
        questions, document_ids)
    data = bidaf_data.DataSet(raw_data, data_type='', shared=shared)

    num_batches = int(math.ceil(data.num_examples / self.config.batch_size))
    e = None
    # Evaluates batch by batch, accumulating the per-batch evaluations.
    for multi_batch in data.get_multi_batches(
        self.config.batch_size, self.config.num_gpus, num_steps=num_batches):
      ei = self.evaluator.get_evaluation(self.sess, multi_batch)
      e = ei if e is None else e + ei
    if self.debug_mode:
      tf.logging.info(e)
      self.graph_handler.dump_answer(e, path=self.config.answer_path)
      self.graph_handler.dump_eval(e, path=self.config.eval_path)
    return e.id2answer_dict, id2questions_dict, e.loss
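
For reference, a minimal usage sketch of this environment (not part of the commit; the paths and the docid below are hypothetical, and a pre-trained BiDAF checkpoint plus preprocessed SQuAD data are assumed to exist):

# Hypothetical paths; substitute real preprocessed-data and checkpoint dirs.
env = BidafEnvironment(
    data_dir='/path/to/squad_preprocessed',
    shared_path='/path/to/shared.json',
    model_dir='/path/to/bidaf_checkpoints',
    load_impossible_questions=True)

# One rewritten question per docid; docids must be unique.
id2answer, id2questions, loss = env.GetAnswers(
    ['In what year did the war end ?'],
    ['571cd9d1dd7acb1400e4c117'])  # hypothetical SQuAD question id
print(id2answer)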
