
Commit 0856a23

Merge pull request huggingface#287 from huggingface/gpt2
Gpt2
2 parents 3a2f97d + ab7f5d2

11 files changed: +1588 −16 lines

README.md

Lines changed: 194 additions & 9 deletions
Large diffs are not rendered by default.

examples/run_gpt2.py

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
#!/usr/bin/env python3

import argparse
import logging
from tqdm import trange

import torch
import torch.nn.functional as F
import numpy as np

from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)

def top_k_logits(logits, k):
    # Keep only the k largest logits per row; everything else is pushed to -1e10
    # so that the softmax downstream assigns it effectively zero probability.
    if k == 0:
        return logits
    values, _ = torch.topk(logits, k)
    min_values = values[:, -1].unsqueeze(1)  # unsqueeze so the per-row threshold broadcasts over the vocab dimension when batch_size > 1
    return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits)

def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda', sample=True):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
        context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
    prev = context
    output = context
    past = None
    with torch.no_grad():
        for i in trange(length):
            logits, past = model(prev, past=past)
            logits = logits[:, -1, :] / temperature
            logits = top_k_logits(logits, k=top_k)
            probs = F.softmax(logits, dim=-1)  # softmax returns probabilities, not log-probabilities
            if sample:
                prev = torch.multinomial(probs, num_samples=1)
            else:
                _, prev = torch.topk(probs, k=1, dim=-1)
            output = torch.cat((output, prev), dim=1)
    return output

def run_model():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', type=str, default='gpt2', help='pretrained model name or path to local checkpoint')
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--nsamples", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=-1)
    parser.add_argument("--length", type=int, default=-1)
    parser.add_argument("--temperature", type=float, default=1.0)  # float so fractional temperatures such as 0.7 work
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument('--unconditional', action='store_true', help='If true, unconditional generation.')
    args = parser.parse_args()
    print(args)

    if args.batch_size == -1:
        args.batch_size = 1
    assert args.nsamples % args.batch_size == 0

    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
    model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
    model.to(device)
    model.eval()

    if args.length == -1:
        args.length = model.config.n_ctx // 2
    elif args.length > model.config.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)

    while True:
        context_tokens = []
        if not args.unconditional:
            raw_text = input("Model prompt >>> ")
            while not raw_text:
                print('Prompt should not be empty!')
                raw_text = input("Model prompt >>> ")
            context_tokens = enc.encode(raw_text)
        generated = 0
        for _ in range(args.nsamples // args.batch_size):
            out = sample_sequence(
                model=model, length=args.length,
                context=context_tokens if not args.unconditional else None,
                start_token=enc.encoder['<|endoftext|>'] if args.unconditional else None,
                batch_size=args.batch_size,
                temperature=args.temperature, top_k=args.top_k, device=device
            )
            out = out[:, len(context_tokens):].tolist()
            for i in range(args.batch_size):
                generated += 1
                text = enc.decode(out[i])
                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                print(text)
        print("=" * 80)
        if args.unconditional:
            break  # in unconditional mode there is no prompt loop: generate the requested samples once, then exit

if __name__ == '__main__':
    run_model()
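To see what the top_k_logits helper above does in isolation, here is a minimal sketch on a made-up batch of logits (values chosen purely for illustration): with k=2, every entry below the second-largest logit in a row is pushed to -1e10, so the softmax that follows gives it effectively zero probability.

import torch

def top_k_logits(logits, k):
    # copy of the helper from the script above
    if k == 0:
        return logits
    values, _ = torch.topk(logits, k)
    min_values = values[:, -1].unsqueeze(1)
    return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits)

logits = torch.tensor([[2.0, 1.0, 0.5, 3.0]])
print(top_k_logits(logits, k=2))
# tensor([[ 2.0000e+00, -1.0000e+10, -1.0000e+10,  3.0000e+00]])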
Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
#!/usr/bin/env python3

import argparse
import logging

import torch
import torch.nn.functional as F
import numpy as np
from tqdm import trange

from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)

def top_k_logits(logits, k):
    if k == 0:
        return logits
    values, _ = torch.topk(logits, k)
    min_values = values[:, -1].unsqueeze(1)  # unsqueeze so the per-row threshold broadcasts over the vocab dimension when batch_size > 1
    return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits)

def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda'):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
        context = torch.tensor(context, device=device, dtype=torch.long)
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
    prev = context
    output = context
    past = None
    with torch.no_grad():
        for i in trange(length):
            logits, past = model(prev, past=past)
            logits = logits[:, -1, :] / temperature
            logits = top_k_logits(logits, k=top_k)
            probs = F.softmax(logits, dim=-1)  # softmax returns probabilities, not log-probabilities
            prev = torch.multinomial(probs, num_samples=1)
            output = torch.cat((output, prev), dim=1)
    return output

def sample_model():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', type=str, default='gpt2', help='pretrained model name or path to local checkpoint')
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--nsamples", type=int, default=0, help='0 means sample indefinitely')
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--length", type=int, default=-1)
    parser.add_argument("--temperature", type=float, default=1.0)  # float so fractional temperatures work
    parser.add_argument("--top_k", type=int, default=0)
    args = parser.parse_args()
    print(args)

    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
    model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
    model.to(device)
    model.eval()

    if args.length == -1:
        args.length = model.config.n_ctx
    elif args.length > model.config.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)

    generated = 0
    while args.nsamples == 0 or generated < args.nsamples:
        out = sample_sequence(
            model=model, length=args.length,
            start_token=enc.encoder['<|endoftext|>'],
            batch_size=args.batch_size,
            temperature=args.temperature, top_k=args.top_k, device=device
        )
        out = out.tolist()
        for i in range(args.batch_size):
            generated += 1  # count each decoded sample once so the sample numbering and the nsamples check stay accurate
            text = enc.decode(out[i])
            print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
            print(text)

if __name__ == '__main__':
    sample_model()
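Both sampling scripts divide the logits by temperature before the softmax. A quick sketch with made-up logits shows the effect: temperatures below 1 sharpen the distribution toward the largest logit, temperatures above 1 flatten it toward uniform, and 1.0 leaves it unchanged.

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.0])
for temperature in (0.5, 1.0, 2.0):
    print(temperature, F.softmax(logits / temperature, dim=-1))
# 0.5 -> ~[0.87, 0.12, 0.02]  (sharper: mass concentrates on the top logit)
# 1.0 -> ~[0.67, 0.24, 0.09]  (unchanged)
# 2.0 -> ~[0.51, 0.31, 0.19]  (flatter: closer to uniform)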

pytorch_pretrained_bert/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -1,7 +1,8 @@
-__version__ = "0.5.1"
+__version__ = "0.6.0"
 from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
 from .tokenization_openai import OpenAIGPTTokenizer
 from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
+from .tokenization_gpt2 import GPT2Tokenizer

 from .modeling import (BertConfig, BertModel, BertForPreTraining,
                        BertForMaskedLM, BertForNextSentencePrediction,
@@ -13,6 +14,9 @@
                         load_tf_weights_in_openai_gpt)
 from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel,
                                   load_tf_weights_in_transfo_xl)
+from .modeling_gpt2 import (GPT2Config, GPT2Model,
+                            GPT2LMHeadModel, GPT2DoubleHeadsModel,
+                            load_tf_weights_in_gpt2)

 from .optimization import BertAdam
 from .optimization_openai import OpenAIAdam
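With these exports in place, the GPT-2 classes load straight from the package root, which is how the example scripts above use them; a minimal sketch:

from pytorch_pretrained_bert import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()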

pytorch_pretrained_bert/__main__.py

Lines changed: 22 additions & 4 deletions
@@ -4,13 +4,15 @@ def main():
     if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [
         "convert_tf_checkpoint_to_pytorch",
         "convert_openai_checkpoint",
-        "convert_transfo_xl_checkpoint"
+        "convert_transfo_xl_checkpoint",
+        "convert_gpt2_checkpoint",
     ]:
         print(
             "Should be used as one of: \n"
             ">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n"
-            ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]` or \n"
-            ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]`")
+            ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n"
+            ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n"
+            ">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`")
     else:
         if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
             try:
@@ -40,7 +42,7 @@ def main():
             convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH,
                                                  OPENAI_GPT_CONFIG,
                                                  PYTORCH_DUMP_OUTPUT)
-        else:
+        elif sys.argv[1] == "convert_transfo_xl_checkpoint":
             try:
                 from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
             except ImportError:
@@ -61,5 +63,21 @@ def main():
             else:
                 TF_CONFIG = ""
             convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE)
+        else:
+            try:
+                from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch
+            except ImportError:
+                print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
+                      "In that case, it requires TensorFlow to be installed. Please see "
+                      "https://www.tensorflow.org/install/ for installation instructions.")
+                raise
+
+            TF_CHECKPOINT = sys.argv[2]
+            PYTORCH_DUMP_OUTPUT = sys.argv[3]
+            if len(sys.argv) == 5:
+                TF_CONFIG = sys.argv[4]
+            else:
+                TF_CONFIG = ""
+            convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)

 if __name__ == '__main__':
     main()
pytorch_pretrained_bert/convert_gpt2_checkpoint_to_pytorch.py

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert OpenAI GPT-2 checkpoint."""

from __future__ import absolute_import, division, print_function

import argparse
from io import open

import torch

from pytorch_pretrained_bert.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME,
                                                   GPT2Config,
                                                   GPT2Model,
                                                   load_tf_weights_in_gpt2)


def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
    # Construct the model (default configuration if no config file is given)
    if gpt2_config_file == "":
        config = GPT2Config()
    else:
        config = GPT2Config(gpt2_config_file)
    model = GPT2Model(config)

    # Load weights from the TensorFlow checkpoint
    load_tf_weights_in_gpt2(model, gpt2_checkpoint_path)

    # Save the PyTorch model weights and configuration
    pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
    pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
    print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
    torch.save(model.state_dict(), pytorch_weights_dump_path)
    print("Save configuration file to {}".format(pytorch_config_dump_path))
    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
        f.write(config.to_json_string())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--gpt2_checkpoint_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the TensorFlow checkpoint.")
    parser.add_argument("--pytorch_dump_folder_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the output PyTorch model.")
    parser.add_argument("--gpt2_config_file",
                        default="",
                        type=str,
                        help="An optional config json file corresponding to the pre-trained OpenAI GPT-2 model.\n"
                             "This specifies the model architecture.")
    args = parser.parse_args()
    convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path,
                                       args.gpt2_config_file,
                                       args.pytorch_dump_folder_path)
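The same conversion can also be driven from Python rather than the command line; a sketch with placeholder paths (an empty config string falls back to the default GPT2Config, as in the function above):

from pytorch_pretrained_bert.convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch

# placeholder paths for illustration
convert_gpt2_checkpoint_to_pytorch("path/to/gpt2/model.ckpt", "", "path/to/pytorch_dump_dir")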
