Add prediction functionality
monologg committed Dec 22, 2019
1 parent f4556dd commit 1d631ca
Showing 6 changed files with 191 additions and 7 deletions.
17 changes: 16 additions & 1 deletion README.md
@@ -29,7 +29,7 @@
- The number of labels is based on the _train_ dataset.
- Add `UNK` for labels (for intent and slot labels that appear only in the _dev_ and _test_ datasets)

## Usage
## Training & Evaluation

```bash
$ python3 main.py --task {task_name} \
@@ -49,6 +49,20 @@ $ python3 main.py --task snips \
--do_train --do_eval
```

## Prediction

- A trained model is required before running prediction.
- Write the sentences to predict, one per line, in `preds.txt` in the `preds` directory.
- The predicted intents and slots are written to `outputs.txt` in the same directory (see the example output after the command below).

```bash
$ python3 main.py --task snips \
--model_type bert \
--model_dir snips_model \
--do_pred \
--pred_dir preds \
--pred_input_file preds.txt
```
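
A hypothetical example of what `preds/outputs.txt` looks like after running the command above (the format follows `Trainer.predict`; the intent and slot labels are illustrative only and depend on the task's label set, not actual model output):

```
text: on april first i need a flight going from phoenix to san diego
intent: atis_flight
slots: O B-depart_date.month_name B-depart_date.day_number O O O O O O B-fromloc.city_name O B-toloc.city_name I-toloc.city_name
```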

## Results

Run 5 epochs each (No hyperparameter tuning)
@@ -68,6 +82,7 @@

- 2019/12/03: Add DistilBert and RoBERTa result
- 2019/12/14: Add Albert(large v1) result
- 2019/12/22: Add sentence prediction (`--do_pred`)

## References

32 changes: 26 additions & 6 deletions main.py
@@ -1,26 +1,39 @@
import argparse

from trainer import Trainer
from utils import init_logger, load_tokenizer, MODEL_CLASSES, MODEL_PATH_MAP
from utils import init_logger, load_tokenizer, read_prediction_text, MODEL_CLASSES, MODEL_PATH_MAP
from data_loader import load_examples


def main(args):
init_logger()
tokenizer = load_tokenizer(args)
train_dataset = load_examples(args, tokenizer, mode="train")
dev_dataset = load_examples(args, tokenizer, mode="dev")
test_dataset = load_examples(args, tokenizer, mode="test")

train_dataset = None
dev_dataset = None
test_dataset = None

if args.do_train:
train_dataset = load_examples(args, tokenizer, mode="train")
dev_dataset = load_examples(args, tokenizer, mode="dev")

if args.do_eval:
test_dataset = load_examples(args, tokenizer, mode="test")

trainer = Trainer(args, train_dataset, dev_dataset, test_dataset)

if args.do_train:
trainer.train()

if args.do_eval:
trainer.load_model()
trainer.evaluate("dev")
trainer.evaluate("test")

if args.do_pred:
trainer.load_model()
texts = read_prediction_text(args)
trainer.predict(texts, tokenizer)


if __name__ == '__main__':
parser = argparse.ArgumentParser()
@@ -55,7 +68,14 @@ def main(args):
parser.add_argument("--no_lower_case", action="store_true", help="Whether not to lowercase the text (For cased model)")
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")

parser.add_argument("--ignore_index", default=-100, type=int, help='Specifies a target value that is ignored and does not contribute to the input gradient')
parser.add_argument("--ignore_index", default=-100, type=int,
help='Specifies a target value that is ignored and does not contribute to the input gradient')

# For prediction
parser.add_argument("--pred_dir", default="./preds", type=str, help="The input prediction dir")
parser.add_argument("--pred_input_file", default="preds.txt", type=str, help="The input text file of lines for prediction")
parser.add_argument("--pred_output_file", default="outputs.txt", type=str, help="The output file of prediction")
parser.add_argument("--do_pred", action="store_true", help="Whether to predict the sentences")

args = parser.parse_args()

1 change: 1 addition & 0 deletions preds/.gitignore
@@ -0,0 +1 @@
outputs.*
5 changes: 5 additions & 0 deletions preds/preds.txt
@@ -0,0 +1,5 @@
i would like to find a flight from charlotte to las vegas that makes a stop in st. louis
on april first i need a ticket from tacoma to san jose departing before 7 am
on april first i need a flight going from phoenix to san diego
i would like a flight traveling one way from phoenix to san diego on april first
i would like a flight from orlando to salt lake city for april first on delta airlines
139 changes: 139 additions & 0 deletions trainer.py
@@ -75,6 +75,11 @@ def train(self):
for step, batch in enumerate(epoch_iterator):
self.model.train()
batch = tuple(t.to(self.device) for t in batch) # GPU or CPU

logger.info("batch[0].size():", batch[0].size())
logger.info("batch[1].size():", batch[1].size())
logger.info("batch[2].size():", batch[2].size())

inputs = {'input_ids': batch[0],
'attention_mask': batch[1],
'intent_label_ids': batch[3],
@@ -231,3 +236,137 @@ def load_model(self):
logger.info("***** Model Loaded *****")
except:
raise Exception("Some model files might be missing...")

def _convert_texts_to_tensors(self, texts, tokenizer,
cls_token_segment_id=0,
pad_token_segment_id=0,
sequence_a_segment_id=0,
mask_padding_with_zero=True):
"""
Only add input_ids, attention_mask, token_type_ids
Labels aren't required.
"""
# Setting based on the current model type
cls_token = tokenizer.cls_token
sep_token = tokenizer.sep_token
unk_token = tokenizer.unk_token
pad_token_id = tokenizer.pad_token_id

input_ids_batch = []
attention_mask_batch = []
token_type_ids_batch = []
slot_label_mask_batch = []

for text in texts:
tokens = []
slot_label_mask = []
for word in text.split():
word_tokens = tokenizer.tokenize(word)
if not word_tokens:
word_tokens = [unk_token]  # Fall back to the unknown token for words the tokenizer cannot encode
tokens.extend(word_tokens)
# Use 0 (a real label position) for the first sub-token of each word and the padding id for the remaining sub-tokens
slot_label_mask.extend([0] + [self.pad_token_label_id] * (len(word_tokens) - 1))

# Account for [CLS] and [SEP]
special_tokens_count = 2
if len(tokens) > self.args.max_seq_len - special_tokens_count:
tokens = tokens[:(self.args.max_seq_len - special_tokens_count)]
slot_label_mask = slot_label_mask[:(self.args.max_seq_len - special_tokens_count)]

# Add [SEP] token
tokens += [sep_token]
slot_label_mask += [self.pad_token_label_id]
token_type_ids = [sequence_a_segment_id] * len(tokens)

# Add [CLS] token
tokens = [cls_token] + tokens
slot_label_mask = [self.pad_token_label_id] + slot_label_mask
token_type_ids = [cls_token_segment_id] + token_type_ids

input_ids = tokenizer.convert_tokens_to_ids(tokens)

# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

# Zero-pad up to the sequence length.
padding_length = self.args.max_seq_len - len(input_ids)
input_ids = input_ids + ([pad_token_id] * padding_length)
attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
slot_label_mask = slot_label_mask + ([self.pad_token_label_id] * padding_length)

input_ids_batch.append(input_ids)
attention_mask_batch.append(attention_mask)
token_type_ids_batch.append(token_type_ids)
slot_label_mask_batch.append(slot_label_mask)

# Convert all texts into a single batch of tensors
input_ids_batch = torch.tensor(input_ids_batch, dtype=torch.long).to(self.device)
attention_mask_batch = torch.tensor(attention_mask_batch, dtype=torch.long).to(self.device)
token_type_ids_batch = torch.tensor(token_type_ids_batch, dtype=torch.long).to(self.device)
slot_label_mask_batch = torch.tensor(slot_label_mask_batch, dtype=torch.long).to(self.device)

return input_ids_batch, attention_mask_batch, token_type_ids_batch, slot_label_mask_batch

def predict(self, orig_texts, tokenizer):
texts = []
if not self.args.no_lower_case:
for cased_text in orig_texts:
texts.append(cased_text.lower())
else:
texts = orig_texts

batch = self._convert_texts_to_tensors(texts, tokenizer)

slot_label_mask = batch[3]
self.model.eval()

# We have only one batch
with torch.no_grad():
inputs = {'input_ids': batch[0],
'attention_mask': batch[1],
'intent_label_ids': None,
'slot_labels_ids': None}
if self.args.model_type != 'distilbert':
inputs['token_type_ids'] = batch[2]
outputs = self.model(**inputs)
_, (intent_logits, slot_logits) = outputs[:2]

print(intent_logits.size())
print(slot_logits.size())

# Intent prediction
intent_preds = intent_logits.detach().cpu().numpy()
intent_preds = np.argmax(intent_preds, axis=1)
intent_list = []
for intent_idx in intent_preds:
intent_list.append(self.intent_label_lst[intent_idx])

# Slot prediction
slot_preds = slot_logits.detach().cpu().numpy()
slot_preds = np.argmax(slot_preds, axis=2)
out_slot_labels_ids = slot_label_mask.detach().cpu().numpy()

slot_label_map = {i: label for i, label in enumerate(self.slot_label_lst)}
out_label_list = [[] for _ in range(out_slot_labels_ids.shape[0])]
slot_preds_list = [[] for _ in range(out_slot_labels_ids.shape[0])]

for i in range(out_slot_labels_ids.shape[0]):
for j in range(out_slot_labels_ids.shape[1]):
if out_slot_labels_ids[i, j] != self.pad_token_label_id:
# out_label_list[i].append(slot_label_map[out_slot_labels_ids[i][j]])
slot_preds_list[i].append(slot_label_map[slot_preds[i][j]])

# Write texts, intent_list and slot_preds_list to the prediction output file
with open(os.path.join(self.args.pred_dir, self.args.pred_output_file), 'w', encoding='utf-8') as f:
for text, intent, slots in zip(orig_texts, intent_list, slot_preds_list):
f.write("text: {}\n".format(text))
f.write("intent: {}\n".format(intent))
f.write("slots: {}\n".format(' '.join(slots)))
f.write("\n")

# Print the contents of the prediction output file
with open(os.path.join(self.args.pred_dir, self.args.pred_output_file), 'r', encoding='utf-8') as f:
print(f.read())
4 changes: 4 additions & 0 deletions utils.py
@@ -74,3 +74,7 @@ def acc_and_f1(preds, labels, average='macro'):
"intent_acc": acc,
"intent_f1": f1,
}


def read_prediction_text(args):
return [text.strip() for text in open(os.path.join(args.pred_dir, args.pred_input_file), 'r', encoding='utf-8')]
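
A minimal sketch of how these prediction pieces fit together, mirroring the `--do_pred` branch of `main.py` (it assumes `args` was built with the same argparse setup as `main.py`):

```python
# Sketch only: assumes `args` was parsed with the same options as main.py
# (model_type, model_dir, task, pred_dir, pred_input_file, pred_output_file, ...).
from trainer import Trainer
from utils import init_logger, load_tokenizer, read_prediction_text

init_logger()
tokenizer = load_tokenizer(args)
trainer = Trainer(args, None, None, None)  # no train/dev/test datasets needed for prediction
trainer.load_model()                       # requires a trained model in args.model_dir
texts = read_prediction_text(args)         # reads {pred_dir}/{pred_input_file}
trainer.predict(texts, tokenizer)          # writes {pred_dir}/{pred_output_file}
```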
