Commit 086ed81 (1 parent: 0c5ebdc). Showing 35 changed files with 10,277 additions and 0 deletions.
@@ -0,0 +1,18 @@
.vscode | ||
experiments/ | ||
data/ | ||
dataset_cache* | ||
dataset1_cache* | ||
daily_dialog_* | ||
runs/ | ||
ParlAI/ | ||
__pycache__ | ||
.idea/* | ||
env/* | ||
ParlAI/* | ||
model/* | ||
logs/* | ||
caches/* | ||
_OpenAIGPTTokenizer | ||
out | ||
emp_transfo_checkpoint/* |
@@ -0,0 +1,56 @@
# EmpTransfo: A Multi-head Transformer Architecture for Creating Empathetic Dialog Systems

This repo contains the code for the paper https://arxiv.org/abs/2003.02958 on empathetic dialog systems. The repository is heavily influenced by https://github.com/huggingface/transfer-learning-conv-ai

## Installation
To install and use the training and inference scripts, clone the repo and install the requirements:

```bash
git clone git@github.com:roholazandie/EmpTransfo.git
cd EmpTransfo
pip install -r requirements.txt
```

## Interact with the chatbot
You can download the checkpoint model [here](https://drive.google.com/open?id=1EjpK0YEVG1i9meLJzt7ZgODr0k65lTDi), extract it, and point to it via the "model_checkpoint" value in interact_config.json.
For example:
```
"model_checkpoint" : "/home/rohola/codes/EmpTransfo/emp_transfo_checkpoint"
```
Then run interact.py:
```bash
python interact.py
```

## Dataset
The original DailyDialog dataset is [here](https://www.aclweb.org/anthology/I17-1099/). We changed its format for our purposes; the processed version can be downloaded from [here](https://drive.google.com/open?id=1T4AdY7wku8srL_xWSxgt-OHqdLFVo3s3).

## Training
The script train_multihead.py uses three heads with all features.

The script train_full.py uses two heads (next-sentence prediction and the LM head), but uses all the features.

The script train_emotion_recognition.py trains the model to predict the next emotion (excluding no_emotion).

The script train.py trains without any features of the dataset (the base model).

For each training script, change dataset_path in the config file related to that task, and then run the script without any arguments, as in the sketch below.
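The snippet below is an illustrative sketch (not part of the repository) of pointing a training config at a local copy of the processed dataset before launching a script; the config file name and dataset path are assumptions and should be adapted to your setup.

```python
import json

# Hypothetical config name; use the config file read by the script you run.
config_path = "configs/train_full_config.json"

with open(config_path) as f:
    cfg = json.load(f)

# Point the config at a local copy of the processed DailyDialog JSON
# (hypothetical path) and pick a device.
cfg["dataset_path"] = "/path/to/daily_dialog.json"
cfg["device"] = "cuda:0"  # or "cpu"

with open(config_path, "w") as f:
    json.dump(cfg, f, indent=2)
```

After saving the config, run the chosen script with no arguments, e.g. `python train_full.py`.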
## Citation
If you use this code in your research, you can cite our ANLP paper:

```
```
@@ -0,0 +1,109 @@
import json | ||
|
||
|
||
class Config: | ||
|
||
def __init__(self, | ||
dataset_path="", | ||
dataset_cache="", | ||
model_checkpoint="", | ||
num_candidates=2, | ||
do_lower_case=True, | ||
max_history=2, | ||
train_batch_size=4, | ||
valid_batch_size=4, | ||
gradient_accumulation_steps=8, | ||
lr=5e-5, | ||
warmup_proportion=0.1, | ||
lm_coef=1, | ||
mc_coef=1, | ||
max_norm=10, | ||
n_epochs=2, | ||
personality_permutations=1, | ||
eval_before_start=False, | ||
device="cpu", | ||
fp16="", | ||
local_rank=-1, | ||
log_dir="", | ||
): | ||
self.dataset_path = dataset_path | ||
self.dataset_cache = dataset_cache | ||
self.model_checkpoint = model_checkpoint | ||
self.num_candidates = num_candidates | ||
self.do_lower_case = do_lower_case | ||
self.max_history = max_history | ||
self.train_batch_size = train_batch_size | ||
self.valid_batch_size = valid_batch_size | ||
self.gradient_accumulation_steps = gradient_accumulation_steps | ||
self.lr = lr | ||
self.warmup_proportion = warmup_proportion | ||
self.lm_coef = lm_coef | ||
self.mc_coef = mc_coef | ||
self.max_norm = max_norm | ||
self.n_epochs = n_epochs | ||
self.personality_permutations = personality_permutations | ||
self.eval_before_start = eval_before_start | ||
self.device = device | ||
self.fp16 = fp16 | ||
self.local_rank = local_rank | ||
self.log_dir = log_dir | ||
|
||
@classmethod | ||
def from_dict(cls, json_object): | ||
config = Config() | ||
for key in json_object: | ||
config.__dict__[key] = json_object[key] | ||
return config | ||
|
||
@classmethod | ||
def from_json_file(cls, json_file): | ||
with open(json_file) as f: | ||
config_json = f.read() | ||
|
||
return cls.from_dict(json.loads(config_json)) | ||
|
||
|
||
class InteractConfig: | ||
|
||
def __init__(self, | ||
dataset_path="", | ||
model="", | ||
dataset_cache="", | ||
model_checkpoint="", | ||
max_history="", | ||
device="", | ||
no_sample="", | ||
max_length="", | ||
min_length="", | ||
seed="", | ||
temperature="", | ||
top_k="", | ||
top_p="" | ||
): | ||
self.dataset_path = dataset_path | ||
self.model = model | ||
self.dataset_cache = dataset_cache | ||
self.model_checkpoint = model_checkpoint | ||
self.max_history = max_history | ||
self.device = device | ||
self.no_sample = no_sample | ||
self.max_length = max_length | ||
self.min_length = min_length | ||
self.seed = seed | ||
self.temperature = temperature | ||
self.top_k = top_k | ||
self.top_p = top_p | ||
|
||
@classmethod | ||
def from_dict(cls, json_object): | ||
config = InteractConfig() | ||
for key in json_object: | ||
config.__dict__[key] = json_object[key] | ||
return config | ||
|
||
@classmethod | ||
def from_json_file(cls, json_file): | ||
with open(json_file) as f: | ||
config_json = f.read() | ||
|
||
return cls.from_dict(json.loads(config_json)) |
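A small usage sketch for the helpers above (the JSON path is an assumption); note that from_dict starts from the default Config and overrides only the keys present in the input, so partial dictionaries are accepted.

```python
# Illustrative only: load a full config from JSON, or build one from a dict.
from config import Config

cfg = Config.from_json_file("configs/train_full_config.json")  # assumed path
print(cfg.lr, cfg.device)

partial = Config.from_dict({"n_epochs": 1, "device": "cpu"})
print(partial.n_epochs, partial.train_batch_size)  # 1 and the default 4
```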
@@ -0,0 +1,15 @@
{ | ||
"dataset_path" : "/home/rohola/data/daily_dialog_full/daily_dialog.json", | ||
"model" : "openai-gpt", | ||
"dataset_cache" : "./caches/dataset_cache_OpenAIGPTTokenizer", | ||
"model_checkpoint" : "/home/rohola/codes/EmpTransfo/emp_transfo_checkpoint", | ||
"max_history" : 2, | ||
"device" : "cpu", | ||
"no_sample" : true, | ||
"max_length" : 20, | ||
"min_length" : 1, | ||
"seed" : 42, | ||
"temperature" : 0.7, | ||
"top_k" : 0, | ||
"top_p" : 0.9 | ||
} |
@@ -0,0 +1,23 @@
{ | ||
"dataset_path": "/home/rohola/data/daily_dialog.json" , | ||
"dataset_cache": "./daily_dialog_dataset_cache", | ||
"model_checkpoint": "openai-gpt", | ||
"num_candidates": 2, | ||
"do_lower_case": true, | ||
"max_history": 2, | ||
"train_batch_size": 1, | ||
"valid_batch_size": 1, | ||
"gradient_accumulation_steps": 8, | ||
"lr": 6.25e-5, | ||
"warmup_proportion": 0.1, | ||
"lm_coef": 1.0, | ||
"mc_coef": 1.0, | ||
"max_norm": 1.0, | ||
"n_epochs": 3, | ||
"personality_permutations":1, | ||
"eval_before_start": false, | ||
"device": "cuda:0", | ||
"fp16": "", | ||
"local_rank": -1, | ||
"log_dir": "" | ||
} |
@@ -0,0 +1,23 @@
{ | ||
"dataset_path": "/home/rohola/data/daily_dialog_full/daily_dialog.json" , | ||
"dataset_cache": "./daily_dialog_dataset_cache", | ||
"model_checkpoint": "openai-gpt", | ||
"num_candidates": 2, | ||
"do_lower_case": true, | ||
"max_history": 2, | ||
"train_batch_size": 1, | ||
"valid_batch_size": 1, | ||
"gradient_accumulation_steps": 8, | ||
"lr": 6.25e-5, | ||
"warmup_proportion": 0.1, | ||
"lm_coef": 1.0, | ||
"mc_coef": 1.0, | ||
"max_norm": 1.0, | ||
"n_epochs": 3, | ||
"personality_permutations":1, | ||
"eval_before_start": false, | ||
"device": "cpu", | ||
"fp16": "", | ||
"local_rank": -1, | ||
"log_dir": "" | ||
} |
@@ -0,0 +1,23 @@
{ | ||
"dataset_path": "/home/rohola/data/daily_dialog_full/daily_dialog.json" , | ||
"dataset_cache": "./caches/daily_dialog_dataset_cache", | ||
"model_checkpoint": "openai-gpt", | ||
"num_candidates": 2, | ||
"do_lower_case": true, | ||
"max_history": 2, | ||
"train_batch_size": 1, | ||
"valid_batch_size": 1, | ||
"gradient_accumulation_steps": 8, | ||
"lr": 6.25e-5, | ||
"warmup_proportion": 0.1, | ||
"lm_coef": 1.0, | ||
"mc_coef": 1.0, | ||
"max_norm": 1.0, | ||
"n_epochs": 3, | ||
"personality_permutations":1, | ||
"eval_before_start": false, | ||
"device": "cuda:0", | ||
"fp16": "", | ||
"local_rank": -1, | ||
"log_dir": "" | ||
} |
@@ -0,0 +1,23 @@
{ | ||
"dataset_path": "/home/rohola/data/daily_dialog_topic/daily_dialog.json" , | ||
"dataset_cache": "caches/daily_dialog_multihead", | ||
"model_checkpoint": "openai-gpt", | ||
"num_candidates": 2, | ||
"do_lower_case": true, | ||
"max_history": 2, | ||
"train_batch_size": 1, | ||
"valid_batch_size": 1, | ||
"gradient_accumulation_steps": 8, | ||
"lr": 6.25e-5, | ||
"warmup_proportion": 0.1, | ||
"lm_coef": 1.0, | ||
"mc_coef": 1.0, | ||
"max_norm": 1.0, | ||
"n_epochs": 3, | ||
"personality_permutations":1, | ||
"eval_before_start": false, | ||
"device": "cuda:0", | ||
"fp16": "", | ||
"local_rank": -1, | ||
"log_dir": "" | ||
} |
@@ -0,0 +1,208 @@
# Copyright (c) 2019-present, HuggingFace Inc. | ||
# All rights reserved. This source code is licensed under the BSD-style license found in the LICENSE file in the root directory of this source tree. | ||
import logging | ||
from pprint import pformat | ||
from collections import defaultdict | ||
from itertools import chain | ||
|
||
import torch | ||
from torch.nn.parallel import DistributedDataParallel | ||
from torch.utils.data import DataLoader, TensorDataset | ||
|
||
from config import Config | ||
from pytorch_pretrained_bert import (OpenAIAdam, OpenAIGPTDoubleHeadLMEmotionRecognitionModel, OpenAIGPTTokenizer, | ||
GPT2DoubleHeadsModel, GPT2Tokenizer, WEIGHTS_NAME, CONFIG_NAME, | ||
BertModel, BertTokenizer) | ||
|
||
from utils import get_dataset, get_dataset_for_daily_dialog | ||
|
||
SPECIAL_TOKENS = ["<bos>", "<eos>", "<speaker1>", "<speaker2>", | ||
"<no_emotion>", "<happiness>", "<surprise>", "<sadness>", "<disgust>", "<anger>", "<fear>", | ||
"<directive>", "<inform>", "<commissive>", "<question>", | ||
"<pad>"] | ||
MODEL_INPUTS = ["input_ids", "mc_token_ids", "lm_labels", "mc_labels", "token_type_ids", "token_emotion_ids"] | ||
PADDED_INPUTS = ["input_ids", "lm_labels", "token_type_ids", "token_emotion_ids"] | ||
|
||
logger = logging.getLogger(__file__) | ||
|
||
def average_distributed_scalar(scalar, config): | ||
""" Average a scalar over the nodes if we are in distributed training. We use this for distributed evaluation. """ | ||
if config.local_rank == -1: | ||
return scalar | ||
scalar_t = torch.tensor(scalar, dtype=torch.float, device=config.device) / torch.distributed.get_world_size() | ||
torch.distributed.all_reduce(scalar_t, op=torch.distributed.ReduceOp.SUM) | ||
return scalar_t.item() | ||
|
||
|
||
def pad_dataset(dataset, padding=0): | ||
""" Pad the dataset. This could be optimized by defining a Dataset class and padd only batches but this is simpler. """ | ||
max_l = max(len(x) for x in dataset["input_ids"]) | ||
for name in PADDED_INPUTS: | ||
dataset[name] = [x + [padding if name != "lm_labels" else -1] * (max_l - len(x)) for x in dataset[name]] | ||
return dataset | ||
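# Illustrative note (not used by the pipeline): pad_dataset pads every field in
# PADDED_INPUTS to the length of the longest "input_ids" entry, using -1 for
# "lm_labels" (ignored by the loss) and `padding` otherwise. Toy values assumed:
#   toy = {"input_ids": [[5, 6, 7], [5]], "lm_labels": [[-1, 6, 7], [-1]],
#          "token_type_ids": [[1, 1, 1], [1]], "token_emotion_ids": [[2, 2, 2], [2]]}
#   pad_dataset(toy, padding=0)["input_ids"]  ->  [[5, 6, 7], [5, 0, 0]]
#   pad_dataset(toy, padding=0)["lm_labels"]  ->  [[-1, 6, 7], [-1, -1, -1]]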
|
||
|
||
def get_emotion_label(tokenizer, candidate_emotion): | ||
_, _, _, _, no_emotion_id, happiness_id, surprise_id, sadness_id, disgust_id, anger_id, fear_id, _, _, _, _, _ = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS) | ||
if candidate_emotion == happiness_id: | ||
return 0 | ||
elif candidate_emotion == surprise_id: | ||
return 1 | ||
elif candidate_emotion == sadness_id: | ||
return 2 | ||
elif candidate_emotion == disgust_id: | ||
return 3 | ||
elif candidate_emotion == anger_id: | ||
return 4 | ||
elif candidate_emotion == fear_id: | ||
return 5 | ||
elif candidate_emotion == no_emotion_id: | ||
return 6 | ||
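# Summary of the mapping above: happiness -> 0, surprise -> 1, sadness -> 2,
# disgust -> 3, anger -> 4, fear -> 5, no_emotion -> 6 (any other id returns None).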
|
||
|
||
def build_input_from_segments(history, emotions, reply, true_emotion, tokenizer, with_eos=True): | ||
""" Build a sequence of input from 3 segments: persona, history and last reply """ | ||
bos, eos, speaker1, speaker2 = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[:4]) | ||
#tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-1]) | ||
|
||
instance = {} | ||
# sequence = [[bos] + history[0] + list(chain(*history[1:]))] + [reply + ([eos] if with_eos else [])] #seq = [personas, history, reply] concatenate all persona sentences | ||
sequence = [[bos] + history[0]] + history[1:] + [reply + ([eos] if with_eos else [])] | ||
sequence = [[speaker2 if (len(sequence)-i) % 2 else speaker1] + s for i, s in enumerate(sequence)] | ||
|
||
instance["input_ids"] = list(chain(*sequence)) | ||
instance["token_type_ids"] = [speaker2 if i % 2 else speaker1 for i, s in enumerate(sequence) for _ in s] # the last for is for repeating the speaker1 and speaker2 for all tokens | ||
#instance["token_emotion_ids"] = [emotions[i] for i, s in enumerate(sequence[:-1]) for _ in s] + [true_emotion] * len(sequence[-1]) | ||
instance["token_emotion_ids"] = [emotions[i] for i, s in enumerate(sequence[:-1]) for _ in s] | ||
|
||
instance["mc_token_ids"] = len(instance["input_ids"]) - 1 | ||
instance["mc_labels"] = get_emotion_label(tokenizer, true_emotion) | ||
instance["lm_labels"] = ([-1] * sum(len(s) for s in sequence[:-1])) + [-1] + sequence[-1][1:] #all -1 except for reply, reply is just the ids | ||
return instance, sequence | ||
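# Illustrative layout (assuming a two-utterance history h0, h1 plus a reply):
# the segments become [<speaker2>, <bos>] + h0, [<speaker1>] + h1, and
# [<speaker2>] + reply + [<eos>]; lm_labels is -1 everywhere except the reply
# tokens and <eos>, and mc_labels holds the class index of true_emotion.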
|
||
|
||
def get_data_loaders(config, tokenizer): | ||
""" Prepare the dataset for training and evaluation """ | ||
personachat = get_dataset_for_daily_dialog(tokenizer, config.dataset_path, config.dataset_cache, SPECIAL_TOKENS) | ||
|
||
#personachat["train"] = personachat["train"][:100] | ||
#personachat["valid"] = personachat["valid"][:10] | ||
|
||
logger.info("Build inputs and labels") | ||
datasets = {"train": defaultdict(list), "valid": defaultdict(list)} | ||
c = 0 | ||
for dataset_name, dataset in personachat.items(): | ||
num_candidates = 2#len(dataset[0]["utterances"][0]["candidates"]) | ||
if config.num_candidates > 0 and dataset_name == 'train': | ||
num_candidates = min(config.num_candidates, num_candidates) | ||
for dialog in dataset: | ||
for utterance in dialog["utterances"]: | ||
history = utterance["history"][-(2 * config.max_history + 1):] | ||
emotions = utterance["emotion"][-(2 * config.max_history + 1):] | ||
reply = utterance["candidates"][-1] | ||
true_emotion = utterance['candidates_emotions'][-1] | ||
if true_emotion == tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)[4]: | ||
continue | ||
instance, _ = build_input_from_segments(history, | ||
emotions, | ||
reply, | ||
true_emotion, | ||
tokenizer) | ||
|
||
if len(instance["input_ids"]) > 310: | ||
truncated_history = [hist[:10] for hist in history] | ||
truncated_candidate = reply[:10] | ||
true_emotion = utterance['candidates_emotions'][-1] | ||
instance, _ = build_input_from_segments(truncated_history, | ||
emotions, | ||
truncated_candidate, | ||
true_emotion, | ||
tokenizer) | ||
c+=1 | ||
|
||
for input_name, input_array in instance.items(): | ||
datasets[dataset_name][input_name].append(input_array) | ||
|
||
#datasets[dataset_name]["mc_labels"].append(num_candidates - 1) | ||
datasets[dataset_name]["n_candidates"] = num_candidates | ||
print(c) | ||
logger.info("Pad inputs and convert to Tensor") | ||
tensor_datasets = {"train": [], "valid": []} | ||
for dataset_name, dataset in datasets.items(): | ||
dataset = pad_dataset(dataset, padding=tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-1])) | ||
for input_name in MODEL_INPUTS: | ||
tensor = torch.tensor(dataset[input_name]) | ||
#if input_name != "mc_labels": | ||
# tensor = tensor.view((-1, datasets[dataset_name]["n_candidates"]) + tensor.shape[1:]) | ||
tensor_datasets[dataset_name].append(tensor) | ||
|
||
logger.info("Build train and validation dataloaders") | ||
train_dataset, valid_dataset = TensorDataset(*tensor_datasets["train"]), TensorDataset(*tensor_datasets["valid"]) | ||
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if config.distributed else None | ||
valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset) if config.distributed else None | ||
train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=config.train_batch_size, shuffle=False) | ||
valid_loader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=config.valid_batch_size, shuffle=False) | ||
|
||
logger.info("Train dataset (Batch, Candidates, Seq length): {}".format(train_dataset.tensors[0].shape)) | ||
logger.info("Valid dataset (Batch, Candidates, Seq length): {}".format(valid_dataset.tensors[0].shape)) | ||
return train_loader, valid_loader, train_sampler, valid_sampler | ||
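# Each batch from these loaders yields tensors in MODEL_INPUTS order:
# input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids.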
|
||
|
||
def train(): | ||
config_file = "configs/train_full_pipeline_config.json" | ||
config = Config.from_json_file(config_file) | ||
|
||
# logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes | ||
logging.basicConfig(level=logging.INFO if config.local_rank in [-1, 0] else logging.WARN) | ||
logger.warning("Running process %d", config.local_rank) # This is a logger.warning: it will be printed by all distributed processes | ||
logger.info("Arguments: %s", pformat(config)) | ||
|
||
# Initialize distributed training if needed | ||
config.distributed = (config.local_rank != -1) | ||
if config.distributed: | ||
torch.cuda.set_device(config.local_rank) | ||
config.device = torch.device("cuda", config.local_rank) | ||
torch.distributed.init_process_group(backend='nccl', init_method='env://') | ||
|
||
logger.info("Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning") | ||
tokenizer_class = GPT2Tokenizer if "gpt2" in config.model_checkpoint else OpenAIGPTTokenizer | ||
tokenizer = tokenizer_class.from_pretrained(config.model_checkpoint) | ||
model_class = GPT2DoubleHeadsModel if "gpt2" in config.model_checkpoint else OpenAIGPTDoubleHeadLMEmotionRecognitionModel | ||
model = model_class.from_pretrained(config.model_checkpoint) | ||
tokenizer.set_special_tokens(SPECIAL_TOKENS) | ||
model.set_num_special_tokens(len(SPECIAL_TOKENS)) | ||
model.to(config.device) | ||
optimizer = OpenAIAdam(model.parameters(), lr=config.lr) | ||
|
||
# Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last) | ||
if config.fp16: | ||
from apex import amp # Apex is only required if we use fp16 training | ||
model, optimizer = amp.initialize(model, optimizer, opt_level=config.fp16) | ||
if config.distributed: | ||
model = DistributedDataParallel(model, device_ids=[config.local_rank], output_device=config.local_rank) | ||
|
||
logger.info("Prepare datasets") | ||
train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(config, tokenizer) | ||
|
||
# Evaluation function and evaluator (evaluator output is the input of the metrics) | ||
model.eval() | ||
num_correct = 0 | ||
num_all = len(val_loader) | ||
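# Note: len(val_loader) is the number of batches, so the ratio printed below is
# per-example accuracy only when valid_batch_size is 1.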
for batch in val_loader: | ||
with torch.no_grad(): | ||
batch = tuple(input_tensor.to(config.device) for input_tensor in batch) | ||
input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids = batch | ||
|
||
model_outputs = model(input_ids, mc_token_ids, token_type_ids=token_type_ids, token_emotion_ids=token_emotion_ids) | ||
lm_logits, mc_logits = model_outputs[0], model_outputs[1] # So we can also use GPT2 outputs | ||
|
||
indices = torch.argmax(mc_logits, dim=1) | ||
|
||
correct = torch.eq(indices, mc_labels).view(-1) | ||
num_correct += torch.sum(correct).item() | ||
|
||
print(num_correct / num_all) | ||
|
||
|
||
if __name__ == "__main__": | ||
train() |
@@ -0,0 +1,195 @@
# # Copyright (c) 2019-present, HuggingFace Inc. | ||
# All rights reserved. | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
import logging | ||
import random | ||
from argparse import ArgumentParser | ||
from itertools import chain | ||
from pprint import pformat | ||
import numpy as np | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
from tqdm import tqdm | ||
|
||
from config import InteractConfig | ||
from pytorch_pretrained_bert import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, GPT2LMHeadModel, GPT2Tokenizer | ||
from utils import download_pretrained_model, get_dataset, _bleu, _f1_score | ||
|
||
|
||
|
||
def build_input_from_segments(persona, history, reply, tokenizer, SPECIAL_TOKENS, lm_labels=False, with_eos=True): | ||
""" Build a sequence of input from 3 segments: persona, history and last reply """ | ||
bos, eos, speaker1, speaker2 = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[:-1]) | ||
|
||
instance = {} | ||
sequence = [[bos] + list(chain(*persona))] + history + [ | ||
reply + ([eos] if with_eos else [])] # seq = [personas, history, reply] concatenate all persona sentences | ||
sequence = [sequence[0]] + [[speaker2 if (len(sequence) - i) % 2 else speaker1] + s for i, s in | ||
enumerate(sequence[1:])] | ||
|
||
instance["input_ids"] = list(chain(*sequence)) | ||
instance["token_type_ids"] = [speaker2 if i % 2 else speaker1 for i, s in enumerate(sequence) for _ in | ||
s] # the last for is for repeating the speaker1 and speaker2 for all tokens | ||
instance["mc_token_ids"] = len(instance["input_ids"]) - 1 | ||
instance["lm_labels"] = [-1] * len(instance["input_ids"]) | ||
if lm_labels: | ||
instance["lm_labels"] = ([-1] * sum(len(s) for s in sequence[:-1])) + [-1] + sequence[-1][1:] # all -1 except for reply, reply is just the ids | ||
return instance, sequence | ||
|
||
|
||
|
||
def top_filtering(logits, top_k=0, top_p=0.0, threshold=-float('Inf'), filter_value=-float('Inf')): | ||
""" Filter a distribution of logits using top-k, top-p (nucleus) and/or threshold filtering | ||
Args: | ||
logits: logits distribution shape (..., vocabulary size) | ||
top_k: <=0: no filtering, >0: keep only top k tokens with highest probability. | ||
top_p: <=0.0: no filtering, >0.0: keep only a subset S of candidates, where S is the smallest subset | ||
whose total probability mass is greater than or equal to the threshold top_p. | ||
In practice, we select the highest probability tokens whose cumulative probability mass exceeds | ||
the threshold top_p. | ||
threshold: a minimal threshold to keep logits | ||
""" | ||
top_k = min(top_k, logits.size(-1)) | ||
if top_k > 0: | ||
# Remove all tokens with a probability less than the last token in the top-k tokens | ||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] | ||
logits[indices_to_remove] = filter_value | ||
|
||
if top_p > 0.0: | ||
# Compute cumulative probabilities of sorted tokens | ||
sorted_logits, sorted_indices = torch.sort(logits, descending=True) | ||
cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) | ||
|
||
# Remove tokens with cumulative probability above the threshold | ||
sorted_indices_to_remove = cumulative_probabilities > top_p | ||
# Shift the indices to the right to keep also the first token above the threshold | ||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() | ||
sorted_indices_to_remove[..., 0] = 0 | ||
|
||
# Back to unsorted indices and set them to -infinity | ||
indices_to_remove = sorted_indices[sorted_indices_to_remove] | ||
logits[indices_to_remove] = filter_value | ||
|
||
indices_to_remove = logits < threshold | ||
logits[indices_to_remove] = filter_value | ||
|
||
return logits | ||
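# Usage sketch (illustrative, not invoked by this script): filter a logits
# vector and sample from the filtered distribution. Toy values assumed.
#   logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
#   filtered = top_filtering(logits.clone(), top_k=2, top_p=0.9)  # clone: filtering is in-place
#   probs = F.softmax(filtered, dim=-1)        # mass only on the kept tokens
#   next_token = torch.multinomial(probs, 1)   # sampled token index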
|
||
|
||
def get_emotions(dataset): | ||
|
||
|
||
for data in tqdm(dataset['valid']): | ||
utterances = data['utterances'] | ||
|
||
for utterance in utterances: | ||
true_emotion = utterance["emotion"] | ||
|
||
|
||
def calculate_metrics(args, model, tokenizer, dataset, special_tokens): | ||
special_tokens_ids = tokenizer.convert_tokens_to_ids(special_tokens) | ||
|
||
all_blues = [] | ||
all_f1_scores = [] | ||
all_true_sentences = [] | ||
all_predicted_sentences = [] | ||
for data in tqdm(dataset['valid']): | ||
personality = data['personality'] | ||
utterances = data['utterances'] | ||
|
||
#utterance = utterances[-1] #only the longest conversation | ||
for utterance in utterances: | ||
true_label = utterance['candidates'][-1] | ||
history = utterance['history'] | ||
predicted_output = [] | ||
for i in range(args.max_length): | ||
instance, _ = build_input_from_segments(personality, history, predicted_output, tokenizer, special_tokens, with_eos=False) | ||
|
||
try: | ||
|
||
if len(instance["input_ids"]) > 310: | ||
truncated_history = [hist[:5] for hist in history] | ||
instance, _ = build_input_from_segments(personality, truncated_history, predicted_output, tokenizer, special_tokens, with_eos=False) | ||
|
||
input_ids = torch.tensor(instance["input_ids"], device=args.device).unsqueeze(0) | ||
token_type_ids = torch.tensor(instance["token_type_ids"], device=args.device).unsqueeze(0) | ||
|
||
logits = model(input_ids, token_type_ids=token_type_ids) | ||
except: | ||
print("exception") | ||
continue | ||
|
||
if "gpt2" == args.model: | ||
logits = logits[0] | ||
logits = logits[0, -1, :] / args.temperature | ||
logits = top_filtering(logits, top_k=args.top_k, top_p=args.top_p) | ||
probs = F.softmax(logits, dim=-1) | ||
|
||
prev = torch.topk(probs, 1)[1] if args.no_sample else torch.multinomial(probs, 1) | ||
# if i < args.min_length and prev.item() in special_tokens_ids: | ||
# k=0 | ||
# while prev.item() in special_tokens_ids and k < 100: | ||
# prev = torch.multinomial(probs, num_samples=1) | ||
# k+=1 | ||
|
||
if i < args.min_length: | ||
prev = torch.multinomial(probs, num_samples=1) | ||
|
||
# if prev.item() in special_tokens_ids: | ||
# break | ||
predicted_output.append(prev.item()) | ||
|
||
predicted_sentence = tokenizer.decode(predicted_output, skip_special_tokens=True) | ||
true_sentence = tokenizer.decode(true_label, skip_special_tokens=True) | ||
#looks like zero gives the best results | ||
|
||
all_predicted_sentences.append(predicted_sentence) | ||
all_true_sentences.append(true_sentence) | ||
|
||
bleus = [_bleu(predicted_sentence, [true_sentence], method="method"+str(i)) for i in [0,1,2,3,5]] | ||
#bleu = _bleu(predicted_sentence, [true_sentence]) | ||
f1_score = _f1_score(predicted_sentence, [true_sentence]) | ||
#print(f1_score) | ||
all_blues.append(bleus) | ||
all_f1_scores.append(f1_score) | ||
#compare predicted and label with bleu | ||
|
||
|
||
print("avg bleu", np.array(all_blues).mean(axis=0)) | ||
print("avg f1 score", np.mean(all_f1_scores)) | ||
print("max bleu", np.array(all_blues).max(axis=0)) | ||
|
||
|
||
def run(): | ||
config_file = "configs/interact_config.json" | ||
config = InteractConfig.from_json_file(config_file) | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
logger = logging.getLogger(__file__) | ||
logger.info(pformat(config)) | ||
|
||
if config.model_checkpoint == "": | ||
config.model_checkpoint = download_pretrained_model() | ||
|
||
random.seed(config.seed) | ||
torch.random.manual_seed(config.seed) | ||
torch.cuda.manual_seed(config.seed) | ||
|
||
logger.info("Get pretrained model and tokenizer") | ||
tokenizer_class = GPT2Tokenizer if "gpt2" == config.model else OpenAIGPTTokenizer | ||
tokenizer = tokenizer_class.from_pretrained(config.model_checkpoint) | ||
model_class = GPT2LMHeadModel if "gpt2" == config.model else OpenAIGPTLMHeadModel | ||
model = model_class.from_pretrained(config.model_checkpoint) | ||
|
||
model.to(config.device) | ||
model.eval() | ||
|
||
dataset = get_dataset(tokenizer, config.dataset_path, config.dataset_cache) | ||
|
||
special_tokens = ["<bos>", "<eos>", "<speaker1>", "<speaker2>", "<pad>"] | ||
calculate_metrics(config, model, tokenizer, dataset, special_tokens) | ||
|
||
if __name__ == "__main__": | ||
run() |
@@ -0,0 +1,161 @@
# # Copyright (c) 2019-present, HuggingFace Inc. | ||
# All rights reserved. | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
import logging | ||
import random | ||
from argparse import ArgumentParser | ||
from itertools import chain | ||
from pprint import pformat | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
|
||
from config import InteractConfig | ||
from pytorch_pretrained_bert import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, GPT2LMHeadModel, GPT2Tokenizer, \ | ||
BertTokenizer | ||
from pytorch_pretrained_bert.modeling import BertLMHeadModel | ||
from utils import get_dataset_personalities, download_pretrained_model, get_dataset | ||
|
||
|
||
def build_input_from_segments(history, reply, tokenizer, SPECIAL_TOKENS, lm_labels=False, with_eos=True): | ||
""" Build a sequence of input from 3 segments: persona, history and last reply """ | ||
bos, eos, speaker1, speaker2 = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[:-1]) | ||
persona = [] | ||
instance = {} | ||
sequence = [[bos] + list(chain(*persona))] + history + [ | ||
reply + ([eos] if with_eos else [])] # seq = [personas, history, reply] concatenate all persona sentences | ||
sequence = [sequence[0]] + [[speaker2 if (len(sequence) - i) % 2 else speaker1] + s for i, s in | ||
enumerate(sequence[1:])] | ||
|
||
instance["input_ids"] = list(chain(*sequence)) | ||
instance["token_type_ids"] = [speaker2 if i % 2 else speaker1 for i, s in enumerate(sequence) for _ in | ||
s] # the last for is for repeating the speaker1 and speaker2 for all tokens | ||
instance["mc_token_ids"] = len(instance["input_ids"]) - 1 | ||
instance["lm_labels"] = [-1] * len(instance["input_ids"]) | ||
if lm_labels: | ||
instance["lm_labels"] = ([-1] * sum(len(s) for s in sequence[:-1])) + [-1] + sequence[-1][1:] # all -1 except for reply, reply is just the ids | ||
return instance, sequence | ||
|
||
|
||
def top_filtering(logits, top_k=0, top_p=0.0, threshold=-float('Inf'), filter_value=-float('Inf')): | ||
""" Filter a distribution of logits using top-k, top-p (nucleus) and/or threshold filtering | ||
Args: | ||
logits: logits distribution shape (..., vocabulary size) | ||
top_k: <=0: no filtering, >0: keep only top k tokens with highest probability. | ||
top_p: <=0.0: no filtering, >0.0: keep only a subset S of candidates, where S is the smallest subset | ||
whose total probability mass is greater than or equal to the threshold top_p. | ||
In practice, we select the highest probability tokens whose cumulative probability mass exceeds | ||
the threshold top_p. | ||
threshold: a minimal threshold to keep logits | ||
""" | ||
top_k = min(top_k, logits.size(-1)) | ||
if top_k > 0: | ||
# Remove all tokens with a probability less than the last token in the top-k tokens | ||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] | ||
logits[indices_to_remove] = filter_value | ||
|
||
if top_p > 0.0: | ||
# Compute cumulative probabilities of sorted tokens | ||
sorted_logits, sorted_indices = torch.sort(logits, descending=True) | ||
cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) | ||
|
||
# Remove tokens with cumulative probability above the threshold | ||
sorted_indices_to_remove = cumulative_probabilities > top_p | ||
# Shift the indices to the right to keep also the first token above the threshold | ||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() | ||
sorted_indices_to_remove[..., 0] = 0 | ||
|
||
# Back to unsorted indices and set them to -infinity | ||
indices_to_remove = sorted_indices[sorted_indices_to_remove] | ||
logits[indices_to_remove] = filter_value | ||
|
||
indices_to_remove = logits < threshold | ||
logits[indices_to_remove] = filter_value | ||
|
||
return logits | ||
|
||
|
||
def sample_sequence(history, tokenizer, model, args, SPECIAL_TOKENS, current_output=None): | ||
special_tokens_ids = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS) | ||
|
||
if current_output is None: | ||
current_output = [] | ||
|
||
for i in range(args.max_length): | ||
instance, sequence = build_input_from_segments(history, current_output, tokenizer, SPECIAL_TOKENS, | ||
with_eos=False) | ||
|
||
input_ids = torch.tensor(instance["input_ids"], device=args.device).unsqueeze(0) | ||
token_type_ids = torch.tensor(instance["token_type_ids"], device=args.device).unsqueeze(0) | ||
|
||
logits = model(input_ids, token_type_ids=token_type_ids) | ||
|
||
if "gpt2" == args.model: | ||
logits = logits[0] | ||
logits = logits[0, -1, :] / args.temperature | ||
logits = top_filtering(logits, top_k=args.top_k, top_p=args.top_p) | ||
probs = F.softmax(logits, dim=-1) | ||
|
||
prev = torch.topk(probs, 1)[1] if args.no_sample else torch.multinomial(probs, 1) | ||
if i < args.min_length and prev.item() in special_tokens_ids: | ||
while prev.item() in special_tokens_ids: | ||
prev = torch.multinomial(probs, num_samples=1) | ||
|
||
if prev.item() in special_tokens_ids: | ||
break | ||
current_output.append(prev.item()) | ||
|
||
return current_output | ||
|
||
|
||
def run(): | ||
config_file = "configs/interact_config.json" | ||
config = InteractConfig.from_json_file(config_file) | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
logger = logging.getLogger(__file__) | ||
logger.info(pformat(config)) | ||
|
||
if config.model_checkpoint == "": | ||
config.model_checkpoint = download_pretrained_model() | ||
|
||
torch.random.manual_seed(config.seed) | ||
torch.cuda.manual_seed(config.seed) | ||
|
||
logger.info("Get pretrained model and tokenizer") | ||
if config.model == "bert": | ||
tokenizer_class = BertTokenizer | ||
model_class = BertLMHeadModel | ||
elif config.model == "gpt2": | ||
tokenizer_class = GPT2Tokenizer | ||
model_class = GPT2LMHeadModel | ||
else: | ||
tokenizer_class = OpenAIGPTTokenizer | ||
model_class = OpenAIGPTLMHeadModel | ||
|
||
SPECIAL_TOKENS = ["<bos>", "<eos>", "<speaker1>", "<speaker2>", "<pad>"] | ||
|
||
tokenizer = tokenizer_class.from_pretrained(config.model_checkpoint) | ||
model = model_class.from_pretrained(config.model_checkpoint) | ||
|
||
model.to(config.device) | ||
model.eval() | ||
|
||
history = [] | ||
while True: | ||
raw_text = input(">>> ") | ||
while not raw_text: | ||
print('Prompt should not be empty!') | ||
raw_text = input(">>> ") | ||
history.append(tokenizer.encode(raw_text)) | ||
with torch.no_grad(): | ||
out_ids = sample_sequence(history, tokenizer, model, config, SPECIAL_TOKENS) | ||
history.append(out_ids) | ||
history = history[-(2 * config.max_history + 1):] | ||
out_text = tokenizer.decode(out_ids, skip_special_tokens=True) | ||
print(out_text) | ||
|
||
|
||
if __name__ == "__main__": | ||
run() |
@@ -0,0 +1,26 @@
__version__ = "0.6.2" | ||
from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer | ||
from .tokenization_openai import OpenAIGPTTokenizer | ||
from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus) | ||
from .tokenization_gpt2 import GPT2Tokenizer | ||
|
||
from .modeling import (BertConfig, BertModel, BertForPreTraining, | ||
BertForMaskedLM, BertForNextSentencePrediction, | ||
BertForSequenceClassification, BertForMultipleChoice, | ||
BertForTokenClassification, BertForQuestionAnswering, | ||
load_tf_weights_in_bert) | ||
from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel, | ||
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, OpenAIGPTDoubleHeadLMEmotionRecognitionModel, | ||
OpenAIGPTForEmotionDetection, | ||
OpenAIGPTMultiHeadModel, | ||
load_tf_weights_in_openai_gpt) | ||
from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel, | ||
load_tf_weights_in_transfo_xl) | ||
from .modeling_gpt2 import (GPT2Config, GPT2Model, | ||
GPT2LMHeadModel, GPT2DoubleHeadsModel, GPT2MultipleChoiceHead, | ||
load_tf_weights_in_gpt2) | ||
|
||
from .optimization import BertAdam | ||
from .optimization_openai import OpenAIAdam | ||
|
||
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path, WEIGHTS_NAME, CONFIG_NAME |
@@ -0,0 +1,83 @@
# coding: utf8 | ||
def main(): | ||
import sys | ||
if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [ | ||
"convert_tf_checkpoint_to_pytorch", | ||
"convert_openai_checkpoint", | ||
"convert_transfo_xl_checkpoint", | ||
"convert_gpt2_checkpoint", | ||
]: | ||
print( | ||
"Should be used as one of: \n" | ||
">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n" | ||
">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n" | ||
">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n" | ||
">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`") | ||
else: | ||
if sys.argv[1] == "convert_tf_checkpoint_to_pytorch": | ||
try: | ||
from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch | ||
except ImportError: | ||
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " | ||
"In that case, it requires TensorFlow to be installed. Please see " | ||
"https://www.tensorflow.org/install/ for installation instructions.") | ||
raise | ||
|
||
if len(sys.argv) != 5: | ||
# pylint: disable=line-too-long | ||
print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") | ||
else: | ||
PYTORCH_DUMP_OUTPUT = sys.argv.pop() | ||
TF_CONFIG = sys.argv.pop() | ||
TF_CHECKPOINT = sys.argv.pop() | ||
convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) | ||
elif sys.argv[1] == "convert_openai_checkpoint": | ||
from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch | ||
OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2] | ||
PYTORCH_DUMP_OUTPUT = sys.argv[3] | ||
if len(sys.argv) == 5: | ||
OPENAI_GPT_CONFIG = sys.argv[4] | ||
else: | ||
OPENAI_GPT_CONFIG = "" | ||
convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH, | ||
OPENAI_GPT_CONFIG, | ||
PYTORCH_DUMP_OUTPUT) | ||
elif sys.argv[1] == "convert_transfo_xl_checkpoint": | ||
try: | ||
from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch | ||
except ImportError: | ||
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " | ||
"In that case, it requires TensorFlow to be installed. Please see " | ||
"https://www.tensorflow.org/install/ for installation instructions.") | ||
raise | ||
|
||
if 'ckpt' in sys.argv[2].lower(): | ||
TF_CHECKPOINT = sys.argv[2] | ||
TF_DATASET_FILE = "" | ||
else: | ||
TF_DATASET_FILE = sys.argv[2] | ||
TF_CHECKPOINT = "" | ||
PYTORCH_DUMP_OUTPUT = sys.argv[3] | ||
if len(sys.argv) == 5: | ||
TF_CONFIG = sys.argv[4] | ||
else: | ||
TF_CONFIG = "" | ||
convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE) | ||
else: | ||
try: | ||
from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch | ||
except ImportError: | ||
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " | ||
"In that case, it requires TensorFlow to be installed. Please see " | ||
"https://www.tensorflow.org/install/ for installation instructions.") | ||
raise | ||
|
||
TF_CHECKPOINT = sys.argv[2] | ||
PYTORCH_DUMP_OUTPUT = sys.argv[3] | ||
if len(sys.argv) == 5: | ||
TF_CONFIG = sys.argv[4] | ||
else: | ||
TF_CONFIG = "" | ||
convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) | ||
if __name__ == '__main__': | ||
main() |
pytorch_pretrained_bert/convert_gpt2_checkpoint_to_pytorch.py (72 additions, 0 deletions)
@@ -0,0 +1,72 @@
# coding=utf-8 | ||
# Copyright 2018 The HuggingFace Inc. team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Convert OpenAI GPT checkpoint.""" | ||
|
||
from __future__ import absolute_import, division, print_function | ||
|
||
import argparse | ||
from io import open | ||
|
||
import torch | ||
|
||
from pytorch_pretrained_bert.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME, | ||
GPT2Config, | ||
GPT2Model, | ||
load_tf_weights_in_gpt2) | ||
|
||
|
||
def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): | ||
# Construct model | ||
if gpt2_config_file == "": | ||
config = GPT2Config() | ||
else: | ||
config = GPT2Config(gpt2_config_file) | ||
model = GPT2Model(config) | ||
|
||
# Load weights from numpy | ||
load_tf_weights_in_gpt2(model, gpt2_checkpoint_path) | ||
|
||
# Save pytorch-model | ||
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME | ||
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME | ||
print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) | ||
torch.save(model.state_dict(), pytorch_weights_dump_path) | ||
print("Save configuration file to {}".format(pytorch_config_dump_path)) | ||
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: | ||
f.write(config.to_json_string()) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
## Required parameters | ||
parser.add_argument("--gpt2_checkpoint_path", | ||
default = None, | ||
type = str, | ||
required = True, | ||
help = "Path the TensorFlow checkpoint path.") | ||
parser.add_argument("--pytorch_dump_folder_path", | ||
default = None, | ||
type = str, | ||
required = True, | ||
help = "Path to the output PyTorch model.") | ||
parser.add_argument("--gpt2_config_file", | ||
default = "", | ||
type = str, | ||
help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" | ||
"This specifies the model architecture.") | ||
args = parser.parse_args() | ||
convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, | ||
args.gpt2_config_file, | ||
args.pytorch_dump_folder_path) |
pytorch_pretrained_bert/convert_openai_checkpoint_to_pytorch.py (72 additions, 0 deletions)
@@ -0,0 +1,72 @@
# coding=utf-8 | ||
# Copyright 2018 The HuggingFace Inc. team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Convert OpenAI GPT checkpoint.""" | ||
|
||
from __future__ import absolute_import, division, print_function | ||
|
||
import argparse | ||
from io import open | ||
|
||
import torch | ||
|
||
from pytorch_pretrained_bert.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME, | ||
OpenAIGPTConfig, | ||
OpenAIGPTModel, | ||
load_tf_weights_in_openai_gpt) | ||
|
||
|
||
def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): | ||
# Construct model | ||
if openai_config_file == "": | ||
config = OpenAIGPTConfig() | ||
else: | ||
config = OpenAIGPTConfig(openai_config_file) | ||
model = OpenAIGPTModel(config) | ||
|
||
# Load weights from numpy | ||
load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path) | ||
|
||
# Save pytorch-model | ||
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME | ||
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME | ||
print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) | ||
torch.save(model.state_dict(), pytorch_weights_dump_path) | ||
print("Save configuration file to {}".format(pytorch_config_dump_path)) | ||
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: | ||
f.write(config.to_json_string()) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
## Required parameters | ||
parser.add_argument("--openai_checkpoint_folder_path", | ||
default = None, | ||
type = str, | ||
required = True, | ||
help = "Path the TensorFlow checkpoint path.") | ||
parser.add_argument("--pytorch_dump_folder_path", | ||
default = None, | ||
type = str, | ||
required = True, | ||
help = "Path to the output PyTorch model.") | ||
parser.add_argument("--openai_config_file", | ||
default = "", | ||
type = str, | ||
help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" | ||
"This specifies the model architecture.") | ||
args = parser.parse_args() | ||
convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path, | ||
args.openai_config_file, | ||
args.pytorch_dump_folder_path) |
pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py (66 additions, 0 deletions)
@@ -0,0 +1,66 @@
# coding=utf-8 | ||
# Copyright 2018 The HuggingFace Inc. team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Convert BERT checkpoint.""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import os | ||
import re | ||
import argparse | ||
import tensorflow as tf | ||
import torch | ||
import numpy as np | ||
|
||
from pytorch_pretrained_bert.modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert | ||
|
||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): | ||
# Initialise PyTorch model | ||
config = BertConfig.from_json_file(bert_config_file) | ||
print("Building PyTorch model from configuration: {}".format(str(config))) | ||
model = BertForPreTraining(config) | ||
|
||
# Load weights from tf checkpoint | ||
load_tf_weights_in_bert(model, tf_checkpoint_path) | ||
|
||
# Save pytorch-model | ||
print("Save PyTorch model to {}".format(pytorch_dump_path)) | ||
torch.save(model.state_dict(), pytorch_dump_path) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
## Required parameters | ||
parser.add_argument("--tf_checkpoint_path", | ||
default = None, | ||
type = str, | ||
required = True, | ||
help = "Path the TensorFlow checkpoint path.") | ||
parser.add_argument("--bert_config_file", | ||
default = None, | ||
type = str, | ||
required = True, | ||
help = "The config json file corresponding to the pre-trained BERT model. \n" | ||
"This specifies the model architecture.") | ||
parser.add_argument("--pytorch_dump_path", | ||
default = None, | ||
type = str, | ||
required = True, | ||
help = "Path to the output PyTorch model.") | ||
args = parser.parse_args() | ||
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, | ||
args.bert_config_file, | ||
args.pytorch_dump_path) |
pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py (116 additions, 0 deletions)
@@ -0,0 +1,116 @@
# coding=utf-8 | ||
# Copyright 2018 The HuggingFace Inc. team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Convert Transformer XL checkpoint and datasets.""" | ||
|
||
from __future__ import absolute_import, division, print_function | ||
|
||
import argparse | ||
import os | ||
import sys | ||
from io import open | ||
|
||
import torch | ||
|
||
import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils | ||
from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME, | ||
WEIGHTS_NAME, | ||
TransfoXLConfig, | ||
TransfoXLLMHeadModel, | ||
load_tf_weights_in_transfo_xl) | ||
from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME, | ||
VOCAB_NAME) | ||
|
||
if sys.version_info[0] == 2: | ||
import cPickle as pickle | ||
else: | ||
import pickle | ||
|
||
# We do this to be able to load python 2 datasets pickles | ||
# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 | ||
data_utils.Vocab = data_utils.TransfoXLTokenizer | ||
data_utils.Corpus = data_utils.TransfoXLCorpus | ||
sys.modules['data_utils'] = data_utils | ||
sys.modules['vocabulary'] = data_utils | ||
|
||
def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, | ||
transfo_xl_config_file, | ||
pytorch_dump_folder_path, | ||
transfo_xl_dataset_file): | ||
if transfo_xl_dataset_file: | ||
# Convert a pre-processed corpus (see original TensorFlow repo) | ||
with open(transfo_xl_dataset_file, "rb") as fp: | ||
corpus = pickle.load(fp, encoding="latin1") | ||
# Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) | ||
pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME | ||
print("Save vocabulary to {}".format(pytorch_vocab_dump_path)) | ||
corpus_vocab_dict = corpus.vocab.__dict__ | ||
torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) | ||
|
||
corpus_dict_no_vocab = corpus.__dict__ | ||
corpus_dict_no_vocab.pop('vocab', None) | ||
pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME | ||
print("Save dataset to {}".format(pytorch_dataset_dump_path)) | ||
torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) | ||
|
||
if tf_checkpoint_path: | ||
# Convert a pre-trained TensorFlow model | ||
config_path = os.path.abspath(transfo_xl_config_file) | ||
tf_path = os.path.abspath(tf_checkpoint_path) | ||
|
||
print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path)) | ||
# Initialise PyTorch model | ||
if transfo_xl_config_file == "": | ||
config = TransfoXLConfig() | ||
else: | ||
config = TransfoXLConfig(transfo_xl_config_file) | ||
print("Building PyTorch model from configuration: {}".format(str(config))) | ||
model = TransfoXLLMHeadModel(config) | ||
|
||
model = load_tf_weights_in_transfo_xl(model, config, tf_path) | ||
# Save pytorch-model | ||
pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) | ||
pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) | ||
print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) | ||
torch.save(model.state_dict(), pytorch_weights_dump_path) | ||
print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) | ||
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: | ||
f.write(config.to_json_string()) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--pytorch_dump_folder_path", | ||
default = None, | ||
type = str, | ||
required = True, | ||
help = "Path to the folder to store the PyTorch model or dataset/vocab.") | ||
parser.add_argument("--tf_checkpoint_path", | ||
default = "", | ||
type = str, | ||
help = "An optional path to a TensorFlow checkpoint path to be converted.") | ||
parser.add_argument("--transfo_xl_config_file", | ||
default = "", | ||
type = str, | ||
help = "An optional config json file corresponding to the pre-trained BERT model. \n" | ||
"This specifies the model architecture.") | ||
parser.add_argument("--transfo_xl_dataset_file", | ||
default = "", | ||
type = str, | ||
help = "An optional dataset file to be converted in a vocabulary.") | ||
args = parser.parse_args() | ||
convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path, | ||
args.transfo_xl_config_file, | ||
args.pytorch_dump_folder_path, | ||
args.transfo_xl_dataset_file) |
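As a usage sketch only (not part of this commit), the conversion utility above could be driven directly from Python; the checkpoint, config, and output paths below are placeholders, and the module name is assumed to match the script added here:

```python
# Hypothetical invocation of the conversion function defined above.
# All paths are placeholders and must point to a real TF Transformer-XL checkpoint.
from convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch

convert_transfo_xl_checkpoint_to_pytorch(
    tf_checkpoint_path="/tmp/transfo-xl-tf/model.ckpt",       # assumed checkpoint location
    transfo_xl_config_file="/tmp/transfo-xl-tf/config.json",  # assumed config location
    pytorch_dump_folder_path="/tmp/transfo-xl-pt",            # where pytorch_model.bin / config.json are written
    transfo_xl_dataset_file="",                               # empty string skips corpus conversion
)
```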
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,279 @@ | ||
""" | ||
Utilities for working with the local dataset cache. | ||
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp | ||
Copyright by the AllenNLP authors. | ||
""" | ||
from __future__ import (absolute_import, division, print_function, unicode_literals) | ||
|
||
import sys | ||
import json | ||
import logging | ||
import os | ||
import shutil | ||
import tempfile | ||
import fnmatch | ||
from functools import wraps | ||
from hashlib import sha256 | ||
import sys | ||
from io import open | ||
|
||
import boto3 | ||
import requests | ||
from botocore.exceptions import ClientError | ||
from tqdm import tqdm | ||
|
||
try: | ||
from torch.hub import _get_torch_home | ||
torch_cache_home = _get_torch_home() | ||
except ImportError: | ||
torch_cache_home = os.path.expanduser( | ||
os.getenv('TORCH_HOME', os.path.join( | ||
os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) | ||
default_cache_path = os.path.join(torch_cache_home, 'pytorch_pretrained_bert') | ||
|
||
try: | ||
from urllib.parse import urlparse | ||
except ImportError: | ||
from urlparse import urlparse | ||
|
||
try: | ||
from pathlib import Path | ||
PYTORCH_PRETRAINED_BERT_CACHE = Path( | ||
os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)) | ||
except (AttributeError, ImportError): | ||
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', | ||
default_cache_path) | ||
|
||
CONFIG_NAME = "config.json" | ||
WEIGHTS_NAME = "pytorch_model.bin" | ||
|
||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name | ||
|
||
|
||
def url_to_filename(url, etag=None): | ||
""" | ||
Convert `url` into a hashed filename in a repeatable way. | ||
If `etag` is specified, append its hash to the url's, delimited | ||
by a period. | ||
""" | ||
url_bytes = url.encode('utf-8') | ||
url_hash = sha256(url_bytes) | ||
filename = url_hash.hexdigest() | ||
|
||
if etag: | ||
etag_bytes = etag.encode('utf-8') | ||
etag_hash = sha256(etag_bytes) | ||
filename += '.' + etag_hash.hexdigest() | ||
|
||
return filename | ||
|
||
|
||
def filename_to_url(filename, cache_dir=None): | ||
""" | ||
Return the url and etag (which may be ``None``) stored for `filename`. | ||
Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. | ||
""" | ||
if cache_dir is None: | ||
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE | ||
if sys.version_info[0] == 3 and isinstance(cache_dir, Path): | ||
cache_dir = str(cache_dir) | ||
|
||
cache_path = os.path.join(cache_dir, filename) | ||
if not os.path.exists(cache_path): | ||
raise EnvironmentError("file {} not found".format(cache_path)) | ||
|
||
meta_path = cache_path + '.json' | ||
if not os.path.exists(meta_path): | ||
raise EnvironmentError("file {} not found".format(meta_path)) | ||
|
||
with open(meta_path, encoding="utf-8") as meta_file: | ||
metadata = json.load(meta_file) | ||
url = metadata['url'] | ||
etag = metadata['etag'] | ||
|
||
return url, etag | ||
|
||
|
||
def cached_path(url_or_filename, cache_dir=None): | ||
""" | ||
Given something that might be a URL (or might be a local path), | ||
determine which. If it's a URL, download the file and cache it, and | ||
return the path to the cached file. If it's already a local path, | ||
make sure the file exists and then return the path. | ||
""" | ||
if cache_dir is None: | ||
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE | ||
if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): | ||
url_or_filename = str(url_or_filename) | ||
if sys.version_info[0] == 3 and isinstance(cache_dir, Path): | ||
cache_dir = str(cache_dir) | ||
|
||
parsed = urlparse(url_or_filename) | ||
|
||
if parsed.scheme in ('http', 'https', 's3'): | ||
# URL, so get it from the cache (downloading if necessary) | ||
return get_from_cache(url_or_filename, cache_dir) | ||
elif os.path.exists(url_or_filename): | ||
# File, and it exists. | ||
return url_or_filename | ||
elif parsed.scheme == '': | ||
# File, but it doesn't exist. | ||
raise EnvironmentError("file {} not found".format(url_or_filename)) | ||
else: | ||
# Something unknown | ||
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) | ||
|
||
|
||
def split_s3_path(url): | ||
"""Split a full s3 path into the bucket name and path.""" | ||
parsed = urlparse(url) | ||
if not parsed.netloc or not parsed.path: | ||
raise ValueError("bad s3 path {}".format(url)) | ||
bucket_name = parsed.netloc | ||
s3_path = parsed.path | ||
# Remove '/' at beginning of path. | ||
if s3_path.startswith("/"): | ||
s3_path = s3_path[1:] | ||
return bucket_name, s3_path | ||
|
||
|
||
def s3_request(func): | ||
""" | ||
Wrapper function for s3 requests in order to create more helpful error | ||
messages. | ||
""" | ||
|
||
@wraps(func) | ||
def wrapper(url, *args, **kwargs): | ||
try: | ||
return func(url, *args, **kwargs) | ||
except ClientError as exc: | ||
if int(exc.response["Error"]["Code"]) == 404: | ||
raise EnvironmentError("file {} not found".format(url)) | ||
else: | ||
raise | ||
|
||
return wrapper | ||
|
||
|
||
@s3_request | ||
def s3_etag(url): | ||
"""Check ETag on S3 object.""" | ||
s3_resource = boto3.resource("s3") | ||
bucket_name, s3_path = split_s3_path(url) | ||
s3_object = s3_resource.Object(bucket_name, s3_path) | ||
return s3_object.e_tag | ||
|
||
|
||
@s3_request | ||
def s3_get(url, temp_file): | ||
"""Pull a file directly from S3.""" | ||
s3_resource = boto3.resource("s3") | ||
bucket_name, s3_path = split_s3_path(url) | ||
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) | ||
|
||
|
||
def http_get(url, temp_file): | ||
req = requests.get(url, stream=True) | ||
content_length = req.headers.get('Content-Length') | ||
total = int(content_length) if content_length is not None else None | ||
progress = tqdm(unit="B", total=total) | ||
for chunk in req.iter_content(chunk_size=1024): | ||
if chunk: # filter out keep-alive new chunks | ||
progress.update(len(chunk)) | ||
temp_file.write(chunk) | ||
progress.close() | ||
|
||
|
||
def get_from_cache(url, cache_dir=None): | ||
""" | ||
Given a URL, look for the corresponding dataset in the local cache. | ||
If it's not there, download it. Then return the path to the cached file. | ||
""" | ||
if cache_dir is None: | ||
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE | ||
if sys.version_info[0] == 3 and isinstance(cache_dir, Path): | ||
cache_dir = str(cache_dir) | ||
|
||
if not os.path.exists(cache_dir): | ||
os.makedirs(cache_dir) | ||
|
||
# Get eTag to add to filename, if it exists. | ||
if url.startswith("s3://"): | ||
etag = s3_etag(url) | ||
else: | ||
try: | ||
response = requests.head(url, allow_redirects=True) | ||
if response.status_code != 200: | ||
etag = None | ||
else: | ||
etag = response.headers.get("ETag") | ||
except EnvironmentError: | ||
etag = None | ||
|
||
if sys.version_info[0] == 2 and etag is not None: | ||
etag = etag.decode('utf-8') | ||
filename = url_to_filename(url, etag) | ||
|
||
# get cache path to put the file | ||
cache_path = os.path.join(cache_dir, filename) | ||
|
||
# If we don't have a connection (etag is None) and can't identify the file | ||
# try to get the last downloaded one | ||
if not os.path.exists(cache_path) and etag is None: | ||
matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') | ||
matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) | ||
if matching_files: | ||
cache_path = os.path.join(cache_dir, matching_files[-1]) | ||
|
||
if not os.path.exists(cache_path): | ||
# Download to temporary file, then copy to cache dir once finished. | ||
# Otherwise you get corrupt cache entries if the download gets interrupted. | ||
with tempfile.NamedTemporaryFile() as temp_file: | ||
logger.info("%s not found in cache, downloading to %s", url, temp_file.name) | ||
|
||
# GET file object | ||
if url.startswith("s3://"): | ||
s3_get(url, temp_file) | ||
else: | ||
http_get(url, temp_file) | ||
|
||
# we are copying the file before closing it, so flush to avoid truncation | ||
temp_file.flush() | ||
# shutil.copyfileobj() starts at the current position, so go to the start | ||
temp_file.seek(0) | ||
|
||
logger.info("copying %s to cache at %s", temp_file.name, cache_path) | ||
with open(cache_path, 'wb') as cache_file: | ||
shutil.copyfileobj(temp_file, cache_file) | ||
|
||
logger.info("creating metadata file for %s", cache_path) | ||
meta = {'url': url, 'etag': etag} | ||
meta_path = cache_path + '.json' | ||
with open(meta_path, 'w') as meta_file: | ||
output_string = json.dumps(meta) | ||
if sys.version_info[0] == 2 and isinstance(output_string, str): | ||
output_string = unicode(output_string, 'utf-8') # The beauty of python 2 | ||
meta_file.write(output_string) | ||
|
||
logger.info("removing temp file %s", temp_file.name) | ||
|
||
return cache_path | ||
|
||
|
||
def read_set_from_file(filename): | ||
''' | ||
Extract a de-duped collection (set) of text from a file. | ||
Expected file format is one item per line. | ||
''' | ||
collection = set() | ||
with open(filename, 'r', encoding='utf-8') as file_: | ||
for line in file_: | ||
collection.add(line.rstrip()) | ||
return collection | ||
|
||
|
||
def get_file_extension(path, dot=True, lower=True): | ||
ext = os.path.splitext(path)[1] | ||
ext = ext if dot else ext[1:] | ||
return ext.lower() if lower else ext |
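For orientation, a minimal sketch of how these cache helpers behave; the module path `pytorch_pretrained_bert.file_utils` and the example URL are assumptions, not something this diff asserts:

```python
# Illustrative sketch of the cache helpers above. A local path is returned
# unchanged; an http(s)/s3 URL would be downloaded into the cache directory
# (PYTORCH_PRETRAINED_BERT_CACHE) and the cached file path returned instead.
# The URL below is a placeholder, not an asset the library actually hosts.
from pytorch_pretrained_bert.file_utils import cached_path, url_to_filename

print(cached_path(__file__))  # existing local file: returned as-is
print(url_to_filename("https://example.com/vocab.json", etag="abc123"))
# -> "<sha256(url)>.<sha256(etag)>", the filename used inside the cache dir
```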
pytorch_pretrained_bert/modeling_transfo_xl_utilities.py: 402 changes (402 additions, 0 deletions). Large diff not rendered.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,302 @@ | ||
# coding=utf-8 | ||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""PyTorch optimization for BERT model.""" | ||
|
||
import math | ||
import torch | ||
from torch.optim import Optimizer | ||
from torch.optim.optimizer import required | ||
from torch.nn.utils import clip_grad_norm_ | ||
import logging | ||
import abc | ||
import sys | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
if sys.version_info >= (3, 4): | ||
ABC = abc.ABC | ||
else: | ||
ABC = abc.ABCMeta('ABC', (), {}) | ||
|
||
|
||
class _LRSchedule(ABC): | ||
""" Parent of all LRSchedules here. """ | ||
warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense | ||
def __init__(self, warmup=0.002, t_total=-1, **kw): | ||
""" | ||
:param warmup: what fraction of t_total steps will be used for linear warmup | ||
:param t_total: how many training steps (updates) are planned | ||
:param kw: | ||
""" | ||
super(_LRSchedule, self).__init__(**kw) | ||
if t_total < 0: | ||
logger.warning("t_total value of {} results in schedule not being applied".format(t_total)) | ||
if not 0.0 <= warmup < 1.0 and not warmup == -1: | ||
raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) | ||
warmup = max(warmup, 0.) | ||
self.warmup, self.t_total = float(warmup), float(t_total) | ||
self.warned_for_t_total_at_progress = -1 | ||
|
||
def get_lr(self, step, nowarn=False): | ||
""" | ||
:param step: which of t_total steps we're on | ||
:param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps | ||
:return: learning rate multiplier for current update | ||
""" | ||
if self.t_total < 0: | ||
return 1. | ||
progress = float(step) / self.t_total | ||
ret = self.get_lr_(progress) | ||
# warning for exceeding t_total (only active with warmup_linear) | ||
if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress: | ||
logger.warning( | ||
"Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly." | ||
.format(ret, self.__class__.__name__)) | ||
self.warned_for_t_total_at_progress = progress | ||
# end warning | ||
return ret | ||
|
||
@abc.abstractmethod | ||
def get_lr_(self, progress): | ||
""" | ||
:param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress | ||
:return: learning rate multiplier for current update | ||
""" | ||
return 1. | ||
|
||
|
||
class ConstantLR(_LRSchedule): | ||
def get_lr_(self, progress): | ||
return 1. | ||
|
||
|
||
class WarmupCosineSchedule(_LRSchedule): | ||
""" | ||
Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. | ||
Decreases learning rate from 1. to 0. over remaining `1 - warmup` steps following a cosine curve. | ||
If `cycles` (default=0.5) is different from the default, the learning rate follows a cosine function after warmup. | ||
""" | ||
warn_t_total = True | ||
def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw): | ||
""" | ||
:param warmup: see LRSchedule | ||
:param t_total: see LRSchedule | ||
:param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1. | ||
:param kw: | ||
""" | ||
super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw) | ||
self.cycles = cycles | ||
|
||
def get_lr_(self, progress): | ||
if progress < self.warmup: | ||
return progress / self.warmup | ||
else: | ||
progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup | ||
return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress)) | ||
|
||
|
||
class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule): | ||
""" | ||
Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. | ||
If `cycles` (default=1.) is different from the default, the learning rate follows `cycles` cosine | ||
decay cycles (with hard restarts). | ||
""" | ||
def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): | ||
super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) | ||
assert(cycles >= 1.) | ||
|
||
def get_lr_(self, progress): | ||
if progress < self.warmup: | ||
return progress / self.warmup | ||
else: | ||
progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup | ||
ret = 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1))) | ||
return ret | ||
|
||
|
||
class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule): | ||
""" | ||
All training progress is divided into `cycles` (default=1.) parts of equal length. | ||
Every part follows a schedule with the first `warmup` fraction of the training steps linearly increasing from 0. to 1., | ||
followed by a learning rate decreasing from 1. to 0. following a cosine curve. | ||
""" | ||
def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw): | ||
assert(warmup * cycles < 1.) | ||
warmup = warmup * cycles if warmup >= 0 else warmup | ||
super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw) | ||
|
||
def get_lr_(self, progress): | ||
progress = progress * self.cycles % 1. | ||
if progress < self.warmup: | ||
return progress / self.warmup | ||
else: | ||
progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup | ||
ret = 0.5 * (1. + math.cos(math.pi * progress)) | ||
return ret | ||
|
||
|
||
class WarmupConstantSchedule(_LRSchedule): | ||
""" | ||
Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. | ||
Keeps learning rate equal to 1. after warmup. | ||
""" | ||
def get_lr_(self, progress): | ||
if progress < self.warmup: | ||
return progress / self.warmup | ||
return 1. | ||
|
||
|
||
class WarmupLinearSchedule(_LRSchedule): | ||
""" | ||
Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps. | ||
Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps. | ||
""" | ||
warn_t_total = True | ||
def get_lr_(self, progress): | ||
if progress < self.warmup: | ||
return progress / self.warmup | ||
return max((progress - 1.) / (self.warmup - 1.), 0.) | ||
|
||
|
||
SCHEDULES = { | ||
None: ConstantLR, | ||
"none": ConstantLR, | ||
"warmup_cosine": WarmupCosineSchedule, | ||
"warmup_constant": WarmupConstantSchedule, | ||
"warmup_linear": WarmupLinearSchedule | ||
} | ||
|
||
|
||
class BertAdam(Optimizer): | ||
"""Implements BERT version of Adam algorithm with weight decay fix. | ||
Params: | ||
lr: learning rate | ||
warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 | ||
t_total: total number of training steps for the learning | ||
rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1 | ||
schedule: schedule to use for the warmup (see above). | ||
Can be `'warmup_linear'`, `'warmup_constant'`, `'warmup_cosine'`, `'none'`, `None` or a `_LRSchedule` object (see below). | ||
If `None` or `'none'`, learning rate is always kept constant. | ||
Default : `'warmup_linear'` | ||
b1: Adams b1. Default: 0.9 | ||
b2: Adams b2. Default: 0.999 | ||
e: Adams epsilon. Default: 1e-6 | ||
weight_decay: Weight decay. Default: 0.01 | ||
max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 | ||
""" | ||
def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', | ||
b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs): | ||
if lr is not required and lr < 0.0: | ||
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) | ||
if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES: | ||
raise ValueError("Invalid schedule parameter: {}".format(schedule)) | ||
if not 0.0 <= b1 < 1.0: | ||
raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) | ||
if not 0.0 <= b2 < 1.0: | ||
raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) | ||
if not e >= 0.0: | ||
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) | ||
# initialize schedule object | ||
if not isinstance(schedule, _LRSchedule): | ||
schedule_type = SCHEDULES[schedule] | ||
schedule = schedule_type(warmup=warmup, t_total=t_total) | ||
else: | ||
if warmup != -1 or t_total != -1: | ||
logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. " | ||
"Please specify custom warmup and t_total in _LRSchedule object.") | ||
defaults = dict(lr=lr, schedule=schedule, | ||
b1=b1, b2=b2, e=e, weight_decay=weight_decay, | ||
max_grad_norm=max_grad_norm) | ||
super(BertAdam, self).__init__(params, defaults) | ||
|
||
def get_lr(self): | ||
lr = [] | ||
for group in self.param_groups: | ||
for p in group['params']: | ||
state = self.state[p] | ||
if len(state) == 0: | ||
return [0] | ||
lr_scheduled = group['lr'] | ||
lr_scheduled *= group['schedule'].get_lr(state['step']) | ||
lr.append(lr_scheduled) | ||
return lr | ||
|
||
def step(self, closure=None): | ||
"""Performs a single optimization step. | ||
Arguments: | ||
closure (callable, optional): A closure that reevaluates the model | ||
and returns the loss. | ||
""" | ||
loss = None | ||
if closure is not None: | ||
loss = closure() | ||
|
||
for group in self.param_groups: | ||
for p in group['params']: | ||
if p.grad is None: | ||
continue | ||
grad = p.grad.data | ||
if grad.is_sparse: | ||
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') | ||
|
||
state = self.state[p] | ||
|
||
# State initialization | ||
if len(state) == 0: | ||
state['step'] = 0 | ||
# Exponential moving average of gradient values | ||
state['next_m'] = torch.zeros_like(p.data) | ||
# Exponential moving average of squared gradient values | ||
state['next_v'] = torch.zeros_like(p.data) | ||
|
||
next_m, next_v = state['next_m'], state['next_v'] | ||
beta1, beta2 = group['b1'], group['b2'] | ||
|
||
# Add grad clipping | ||
if group['max_grad_norm'] > 0: | ||
clip_grad_norm_(p, group['max_grad_norm']) | ||
|
||
# Decay the first and second moment running average coefficient | ||
# In-place operations to update the averages at the same time | ||
next_m.mul_(beta1).add_(1 - beta1, grad) | ||
next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) | ||
update = next_m / (next_v.sqrt() + group['e']) | ||
|
||
# Just adding the square of the weights to the loss function is *not* | ||
# the correct way of using L2 regularization/weight decay with Adam, | ||
# since that will interact with the m and v parameters in strange ways. | ||
# | ||
# Instead we want to decay the weights in a manner that doesn't interact | ||
# with the m/v parameters. This is equivalent to adding the square | ||
# of the weights to the loss with plain (non-momentum) SGD. | ||
if group['weight_decay'] > 0.0: | ||
update += group['weight_decay'] * p.data | ||
|
||
lr_scheduled = group['lr'] | ||
lr_scheduled *= group['schedule'].get_lr(state['step']) | ||
|
||
update_with_lr = lr_scheduled * update | ||
p.data.add_(-update_with_lr) | ||
|
||
state['step'] += 1 | ||
|
||
# step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 | ||
# No bias correction | ||
# bias_correction1 = 1 - beta1 ** state['step'] | ||
# bias_correction2 = 1 - beta2 ** state['step'] | ||
|
||
return loss |
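As a usage sketch (not taken from this repository's training scripts), `BertAdam` could be wired up with the `warmup_linear` schedule roughly as below; the model, learning rate, and step count are placeholder assumptions, and the training loop assumes the PyTorch version this repo pins:

```python
# Minimal sketch: BertAdam with linear warmup over 10% of 1000 planned updates.
import torch
from pytorch_pretrained_bert.optimization import BertAdam

model = torch.nn.Linear(768, 2)  # placeholder model
optimizer = BertAdam(model.parameters(),
                     lr=6.25e-5,
                     warmup=0.1,            # fraction of t_total used for linear warmup
                     t_total=1000,          # planned number of optimizer updates
                     schedule='warmup_linear',
                     weight_decay=0.01,
                     max_grad_norm=1.0)

for step in range(3):  # stand-in for the real training loop
    loss = model(torch.randn(4, 768)).sum()
    loss.backward()
    optimizer.step()   # applies warmup-scaled LR, decoupled weight decay, grad clipping
    optimizer.zero_grad()
```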
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
# coding=utf-8 | ||
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""PyTorch optimization for OpenAI GPT model.""" | ||
|
||
import math | ||
import torch | ||
from torch.optim import Optimizer | ||
from torch.optim.optimizer import required | ||
from torch.nn.utils import clip_grad_norm_ | ||
import logging | ||
from .optimization import SCHEDULES, _LRSchedule, WarmupCosineWithWarmupRestartsSchedule, \ | ||
WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule, WarmupLinearSchedule, WarmupConstantSchedule | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class OpenAIAdam(Optimizer): | ||
"""Implements Open AI version of Adam algorithm with weight decay fix. | ||
""" | ||
def __init__(self, params, lr=required, schedule='warmup_linear', warmup=-1, t_total=-1, | ||
b1=0.9, b2=0.999, e=1e-8, weight_decay=0, | ||
vector_l2=False, max_grad_norm=-1, **kwargs): | ||
if lr is not required and lr < 0.0: | ||
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) | ||
if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES: | ||
raise ValueError("Invalid schedule parameter: {}".format(schedule)) | ||
if not 0.0 <= b1 < 1.0: | ||
raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) | ||
if not 0.0 <= b2 < 1.0: | ||
raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) | ||
if not e >= 0.0: | ||
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) | ||
# initialize schedule object | ||
if not isinstance(schedule, _LRSchedule): | ||
schedule_type = SCHEDULES[schedule] | ||
schedule = schedule_type(warmup=warmup, t_total=t_total) | ||
else: | ||
if warmup != -1 or t_total != -1: | ||
logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. " | ||
"Please specify custom warmup and t_total in _LRSchedule object.") | ||
defaults = dict(lr=lr, schedule=schedule, | ||
b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2, | ||
max_grad_norm=max_grad_norm) | ||
super(OpenAIAdam, self).__init__(params, defaults) | ||
|
||
def get_lr(self): | ||
lr = [] | ||
for group in self.param_groups: | ||
for p in group['params']: | ||
state = self.state[p] | ||
if len(state) == 0: | ||
return [0] | ||
lr_scheduled = group['lr'] | ||
lr_scheduled *= group['schedule'].get_lr(state['step']) | ||
lr.append(lr_scheduled) | ||
return lr | ||
|
||
def step(self, closure=None): | ||
"""Performs a single optimization step. | ||
Arguments: | ||
closure (callable, optional): A closure that reevaluates the model | ||
and returns the loss. | ||
""" | ||
loss = None | ||
if closure is not None: | ||
loss = closure() | ||
|
||
for group in self.param_groups: | ||
for p in group['params']: | ||
if p.grad is None: | ||
continue | ||
grad = p.grad.data | ||
if grad.is_sparse: | ||
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') | ||
|
||
state = self.state[p] | ||
|
||
# State initialization | ||
if len(state) == 0: | ||
state['step'] = 0 | ||
# Exponential moving average of gradient values | ||
state['exp_avg'] = torch.zeros_like(p.data) | ||
# Exponential moving average of squared gradient values | ||
state['exp_avg_sq'] = torch.zeros_like(p.data) | ||
|
||
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] | ||
beta1, beta2 = group['b1'], group['b2'] | ||
|
||
state['step'] += 1 | ||
|
||
# Add grad clipping | ||
if group['max_grad_norm'] > 0: | ||
clip_grad_norm_(p, group['max_grad_norm']) | ||
|
||
# Decay the first and second moment running average coefficient | ||
exp_avg.mul_(beta1).add_(1 - beta1, grad) | ||
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) | ||
denom = exp_avg_sq.sqrt().add_(group['e']) | ||
|
||
bias_correction1 = 1 - beta1 ** state['step'] | ||
bias_correction2 = 1 - beta2 ** state['step'] | ||
|
||
lr_scheduled = group['lr'] | ||
lr_scheduled *= group['schedule'].get_lr(state['step']) | ||
|
||
step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 | ||
|
||
p.data.addcdiv_(-step_size, exp_avg, denom) | ||
|
||
# Add weight decay at the end (fixed version) | ||
if (len(p.size()) > 1 or group['vector_l2']) and group['weight_decay'] > 0: | ||
p.data.add_(-lr_scheduled * group['weight_decay'], p.data) | ||
|
||
return loss |
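To make the schedule/optimizer split concrete, here is a small sketch that passes an explicit schedule object to `OpenAIAdam`; the hyperparameters are illustrative, and the module path `pytorch_pretrained_bert.optimization_openai` is assumed:

```python
# Sketch: OpenAIAdam driven by an explicit WarmupCosineSchedule object.
# When a schedule object is supplied, warmup/t_total on the optimizer are ignored.
import torch
from pytorch_pretrained_bert.optimization import WarmupCosineSchedule
from pytorch_pretrained_bert.optimization_openai import OpenAIAdam

schedule = WarmupCosineSchedule(warmup=0.1, t_total=500)  # linear warmup, then cosine decay
print([round(schedule.get_lr(s), 3) for s in (0, 25, 50, 250, 500)])  # LR multipliers over training

params = [torch.nn.Parameter(torch.randn(10, 10))]  # placeholder parameters
optimizer = OpenAIAdam(params, lr=6.25e-5, schedule=schedule,
                       weight_decay=0.01, max_grad_norm=1.0)
```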
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,311 @@ | ||
# coding=utf-8 | ||
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Tokenization classes for OpenAI GPT.""" | ||
from __future__ import (absolute_import, division, print_function, | ||
unicode_literals) | ||
|
||
import sys | ||
import json | ||
import logging | ||
import os | ||
import regex as re | ||
from io import open | ||
|
||
try: | ||
from functools import lru_cache | ||
except ImportError: | ||
# Just a dummy decorator to get the checks to run on python2 | ||
# because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. | ||
def lru_cache(): | ||
return lambda func: func | ||
|
||
from .file_utils import cached_path | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
PRETRAINED_VOCAB_ARCHIVE_MAP = { | ||
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", | ||
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json", | ||
} | ||
PRETRAINED_MERGES_ARCHIVE_MAP = { | ||
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", | ||
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt", | ||
} | ||
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { | ||
'gpt2': 1024, | ||
} | ||
VOCAB_NAME = 'vocab.json' | ||
MERGES_NAME = 'merges.txt' | ||
SPECIAL_TOKENS_NAME = 'special_tokens.txt' | ||
|
||
@lru_cache() | ||
def bytes_to_unicode(): | ||
""" | ||
Returns list of utf-8 byte and a corresponding list of unicode strings. | ||
The reversible bpe codes work on unicode strings. | ||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. | ||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. | ||
This is a significant percentage of your normal, say, 32K bpe vocab. | ||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings. | ||
And avoids mapping to whitespace/control characters the bpe code barfs on. | ||
""" | ||
_chr = unichr if sys.version_info[0] == 2 else chr | ||
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) | ||
cs = bs[:] | ||
n = 0 | ||
for b in range(2**8): | ||
if b not in bs: | ||
bs.append(b) | ||
cs.append(2**8+n) | ||
n += 1 | ||
cs = [_chr(n) for n in cs] | ||
return dict(zip(bs, cs)) | ||
|
||
def get_pairs(word): | ||
"""Return set of symbol pairs in a word. | ||
Word is represented as tuple of symbols (symbols being variable-length strings). | ||
""" | ||
pairs = set() | ||
prev_char = word[0] | ||
for char in word[1:]: | ||
pairs.add((prev_char, char)) | ||
prev_char = char | ||
return pairs | ||
|
||
class GPT2Tokenizer(object): | ||
""" | ||
GPT-2 BPE tokenizer. Peculiarities: | ||
- Byte-level BPE | ||
""" | ||
@classmethod | ||
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): | ||
""" | ||
Instantiate a GPT2Tokenizer from a pre-trained model file. | ||
Download and cache the pre-trained model file if needed. | ||
""" | ||
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: | ||
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] | ||
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] | ||
special_tokens_file = None | ||
else: | ||
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) | ||
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) | ||
special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) | ||
if not os.path.exists(special_tokens_file): | ||
special_tokens_file = None | ||
else: | ||
logger.info("loading special tokens file {}".format(special_tokens_file)) | ||
# redirect to the cache, if necessary | ||
try: | ||
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) | ||
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) | ||
except EnvironmentError: | ||
logger.error( | ||
"Model name '{}' was not found in model name list ({}). " | ||
"We assumed '{}' was a path or url but couldn't find files {} and {} " | ||
"at this path or url.".format( | ||
pretrained_model_name_or_path, | ||
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), | ||
pretrained_model_name_or_path, | ||
vocab_file, merges_file)) | ||
return None | ||
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: | ||
logger.info("loading vocabulary file {}".format(vocab_file)) | ||
logger.info("loading merges file {}".format(merges_file)) | ||
else: | ||
logger.info("loading vocabulary file {} from cache at {}".format( | ||
vocab_file, resolved_vocab_file)) | ||
logger.info("loading merges file {} from cache at {}".format( | ||
merges_file, resolved_merges_file)) | ||
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: | ||
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer | ||
# than the number of positional embeddings | ||
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] | ||
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) | ||
# Instantiate tokenizer. | ||
if special_tokens_file and 'special_tokens' not in kwargs: | ||
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] | ||
else: | ||
special_tokens = kwargs.pop('special_tokens', []) | ||
tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs) | ||
return tokenizer | ||
|
||
def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None): | ||
self.max_len = max_len if max_len is not None else int(1e12) | ||
self.encoder = json.load(open(vocab_file)) | ||
self.decoder = {v:k for k,v in self.encoder.items()} | ||
self.errors = errors # how to handle errors in decoding | ||
self.byte_encoder = bytes_to_unicode() | ||
self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} | ||
bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] | ||
bpe_merges = [tuple(merge.split()) for merge in bpe_data] | ||
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) | ||
self.cache = {} | ||
|
||
# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions | ||
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") | ||
|
||
self.special_tokens = {} | ||
self.special_tokens_decoder = {} | ||
self.set_special_tokens(special_tokens) | ||
|
||
def __len__(self): | ||
return len(self.encoder) + len(self.special_tokens) | ||
|
||
def set_special_tokens(self, special_tokens): | ||
""" Add a list of additional tokens to the encoder. | ||
The additional tokens are indexed starting from the last index of the | ||
current vocabulary in the order of the `special_tokens` list. | ||
""" | ||
if not special_tokens: | ||
self.special_tokens = {} | ||
self.special_tokens_decoder = {} | ||
return | ||
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) | ||
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()} | ||
logger.info("Special tokens {}".format(self.special_tokens)) | ||
|
||
def bpe(self, token): | ||
if token in self.cache: | ||
return self.cache[token] | ||
word = tuple(token) | ||
pairs = get_pairs(word) | ||
|
||
if not pairs: | ||
return token | ||
|
||
while True: | ||
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) | ||
if bigram not in self.bpe_ranks: | ||
break | ||
first, second = bigram | ||
new_word = [] | ||
i = 0 | ||
while i < len(word): | ||
try: | ||
j = word.index(first, i) | ||
new_word.extend(word[i:j]) | ||
i = j | ||
except ValueError: | ||
new_word.extend(word[i:]) | ||
break | ||
|
||
if word[i] == first and i < len(word)-1 and word[i+1] == second: | ||
new_word.append(first+second) | ||
i += 2 | ||
else: | ||
new_word.append(word[i]) | ||
i += 1 | ||
new_word = tuple(new_word) | ||
word = new_word | ||
if len(word) == 1: | ||
break | ||
else: | ||
pairs = get_pairs(word) | ||
word = ' '.join(word) | ||
self.cache[token] = word | ||
return word | ||
|
||
def tokenize(self, text): | ||
""" Tokenize a string. """ | ||
bpe_tokens = [] | ||
for token in re.findall(self.pat, text): | ||
if sys.version_info[0] == 2: | ||
token = ''.join(self.byte_encoder[ord(b)] for b in token) | ||
else: | ||
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) | ||
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) | ||
return bpe_tokens | ||
|
||
def convert_tokens_to_ids(self, tokens): | ||
""" Converts a sequence of tokens into ids using the vocab. """ | ||
ids = [] | ||
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): | ||
if tokens in self.special_tokens: | ||
return self.special_tokens[tokens] | ||
else: | ||
return self.encoder.get(tokens, 0) | ||
for token in tokens: | ||
if token in self.special_tokens: | ||
ids.append(self.special_tokens[token]) | ||
else: | ||
ids.append(self.encoder.get(token, 0)) | ||
if len(ids) > self.max_len: | ||
logger.warning( | ||
"Token indices sequence length is longer than the specified maximum " | ||
" sequence length for this OpenAI GPT model ({} > {}). Running this" | ||
" sequence through the model will result in indexing errors".format(len(ids), self.max_len) | ||
) | ||
return ids | ||
|
||
def convert_ids_to_tokens(self, ids, skip_special_tokens=False): | ||
"""Converts a sequence of ids in BPE tokens using the vocab.""" | ||
tokens = [] | ||
for i in ids: | ||
if i in self.special_tokens_decoder: | ||
if not skip_special_tokens: | ||
tokens.append(self.special_tokens_decoder[i]) | ||
else: | ||
tokens.append(self.decoder[i]) | ||
return tokens | ||
|
||
def encode(self, text): | ||
return self.convert_tokens_to_ids(self.tokenize(text)) | ||
|
||
def decode(self, tokens, skip_special_tokens=False, clean_up_tokenization_spaces=True): | ||
text = ''.join(self.convert_ids_to_tokens(tokens, skip_special_tokens=skip_special_tokens)) | ||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) | ||
if clean_up_tokenization_spaces: | ||
text = text.replace('<unk>', '') | ||
text = text.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',' | ||
).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't" | ||
).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re") | ||
return text | ||
|
||
def save_vocabulary(self, vocab_path): | ||
"""Save the tokenizer vocabulary and merge files to a directory.""" | ||
if not os.path.isdir(vocab_path): | ||
logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) | ||
return | ||
vocab_file = os.path.join(vocab_path, VOCAB_NAME) | ||
merge_file = os.path.join(vocab_path, MERGES_NAME) | ||
special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) | ||
|
||
with open(vocab_file, 'w', encoding='utf-8') as f: | ||
f.write(json.dumps(self.encoder, ensure_ascii=False)) | ||
|
||
index = 0 | ||
with open(merge_file, "w", encoding="utf-8") as writer: | ||
writer.write(u'#version: 0.2\n') | ||
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): | ||
if index != token_index: | ||
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." | ||
" Please check that the tokenizer is not corrupted!".format(merge_file)) | ||
index = token_index | ||
writer.write(' '.join(bpe_tokens) + u'\n') | ||
index += 1 | ||
|
||
index = len(self.encoder) | ||
with open(special_tokens_file, 'w', encoding='utf-8') as writer: | ||
for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]): | ||
if index != token_index: | ||
logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive." | ||
" Please check that the tokenizer is not corrupted!".format(special_tokens_file)) | ||
index = token_index | ||
writer.write(token + u'\n') | ||
index += 1 | ||
|
||
return vocab_file, merge_file, special_tokens_file |
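Finally, a short sketch of the tokenizer round-trip; it assumes the module path `pytorch_pretrained_bert.tokenization_gpt2` and that the download URLs listed above are still reachable for `from_pretrained("gpt2")`:

```python
# Sketch: byte-level BPE encode/decode round-trip with the GPT-2 tokenizer.
from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")  # fetches vocab.json / merges.txt on first use
ids = tokenizer.encode("Hello world!")             # tokenize + convert_tokens_to_ids
print(ids)
print(tokenizer.decode(ids))                       # round-trips back to "Hello world!"

# Special tokens are indexed after the base vocabulary.
tokenizer.set_special_tokens(["<bos>", "<eos>"])
print(tokenizer.special_tokens)                    # e.g. {'<bos>': 50257, '<eos>': 50258}
```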