Skip to content

Commit

Permalink
update opts and code style.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhezhaoa committed Jan 6, 2021
1 parent 2121a37 commit ad105ec
Show file tree
Hide file tree
Showing 11 changed files with 99 additions and 149 deletions.
6 changes: 2 additions & 4 deletions preprocess.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import argparse
import torch
from uer.utils.data import *
from uer.utils import *
import six
from packaging import version
from uer.utils.data import *
from uer.utils import *


assert version.parse(six.__version__) >= version.parse("1.12.0")
Expand Down Expand Up @@ -76,4 +75,3 @@ def main():

if __name__ == "__main__":
main()

32 changes: 6 additions & 26 deletions pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
import uer.trainer as trainer
from uer.utils.config import load_hyperparam
from uer.opts import *


def main():
Expand Down Expand Up @@ -40,36 +41,21 @@ def main():
help="The buffer size of instances in memory.")
parser.add_argument("--labels_num", type=int, required=False,
help="Number of prediction labels.")

# Model options.
parser.add_argument("--dropout", type=float, default=0.1, help="Dropout value.")
parser.add_argument("--seed", type=int, default=7, help="Random seed.")
parser.add_argument("--embedding", choices=["word", "word_pos", "word_pos_seg", "word_sinusoidalpos"], default="word_pos_seg",
help="Emebdding type.")

# Model options.
model_opts(parser)
parser.add_argument("--tgt_embedding", choices=["word", "word_pos", "word_pos_seg", "word_sinusoidalpos"], default="word_pos_seg",
help="Target embedding type.")
parser.add_argument("--remove_embedding_layernorm", action="store_true",
help="Remove layernorm on embedding.")
parser.add_argument("--encoder", choices=["transformer", "rnn", "lstm", "gru", \
"birnn", "bilstm", "bigru", \
"gatedcnn"], \
default="transformer", help="Encoder type.")
parser.add_argument("--decoder", choices=["transformer"], \
default="transformer", help="Decoder type.")
parser.add_argument("--pooling", choices=["mean", "max", "first", "last"], default="first",
help="Pooling type.")
parser.add_argument("--mask", choices=["fully_visible", "causal"], default="fully_visible",
help="Mask type.")
parser.add_argument("--layernorm_positioning", choices=["pre", "post"], default="post",
help="Layernorm positioning.")
parser.add_argument("--bidirectional", action="store_true", help="Specific to recurrent model.")
parser.add_argument("--target", choices=["bert", "lm", "mlm", "bilm", "albert", "mt", "t5", "cls"], default="bert",
help="The training target of the pretraining model.")
parser.add_argument("--tie_weights", action="store_true",
help="Tie the word embedding and softmax weights.")
parser.add_argument("--factorized_embedding_parameterization", action="store_true",
help="Factorized embedding parameterization.")
parser.add_argument("--parameter_sharing", action="store_true", help="Parameter sharing.")
parser.add_argument("--has_lmtarget_bias", action="store_true",
help="Add bias on output_layer for lm target.")

Expand All @@ -81,15 +67,9 @@ def main():
help="Max length for span masking.")

# Optimizer options.
parser.add_argument("--learning_rate", type=float, default=2e-5, help="Initial learning rate.")
parser.add_argument("--warmup", type=float, default=0.1, help="Warm up value.")
optimization_opts(parser)
parser.add_argument("--beta1", type=float, default=0.9, help="Beta1 for Adam optimizer.")
parser.add_argument("--beta2", type=float, default=0.999, help="Beta2 for Adam optimizer.")
parser.add_argument("--fp16", action='store_true',
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
parser.add_argument("--fp16_opt_level", choices=["O0", "O1", "O2", "O3" ], default='O1',
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html")

# GPU options.
parser.add_argument("--world_size", type=int, default=1, help="Total number of processes (GPUs) for training.")
Expand All @@ -101,7 +81,7 @@ def main():
args = parser.parse_args()

if args.target == "cls":
assert args.labels_num is not None, "Cls target needs the number of prediction labels."
assert args.labels_num is not None, "Cls target needs the denotation of the number of labels."

# Load hyper-parameters from config file.
if args.config_path:
Expand Down
7 changes: 3 additions & 4 deletions run_c3.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
"""
import argparse
import json
import torch
import random
import torch
import torch.nn as nn
from uer.layers import *
from uer.encoders import *
from uer.utils.vocab import Vocab
from uer.utils.constants import *
from uer.utils import *
from uer.utils.optimizers import *
Expand Down Expand Up @@ -168,7 +167,7 @@ def main():
model = torch.nn.DataParallel(model)
args.model = model

total_loss, result, best_result = 0., 0., 0.
total_loss, result, best_result = 0.0, 0.0, 0.0

print("Start training.")

Expand All @@ -181,7 +180,7 @@ def main():

if (i + 1) % args.report_steps == 0:
print("Epoch id: {}, Training steps: {}, Avg loss: {:.3f}".format(epoch, i+1, total_loss / args.report_steps))
total_loss = 0.
total_loss = 0.0

result = evaluate(args, read_dataset(args, args.dev_path))
if result[0] > best_result:
Expand Down
48 changes: 22 additions & 26 deletions run_chid.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@
"""
import argparse
import json
import torch
import random
import torch.nn as nn
import torch
from uer.layers import *
from uer.encoders import *
from uer.utils.vocab import Vocab
from uer.utils.constants import *
from uer.utils.tokenizers import *
from uer.utils.optimizers import *
Expand All @@ -27,17 +25,17 @@ def tokenize_chid(text):
if first_idiom:
idiom_index = text.find("#idiom")
output.extend(text[:idiom_index])
output.append(text[idiom_index:idiom_index+13])
output.append(text[idiom_index : idiom_index + 13])
pre_idiom_index = idiom_index
first_idiom = False
else:
if text[idiom_index+1:].find("#idiom") == -1:
output.extend(text[pre_idiom_index+13:])
if text[idiom_index + 1 :].find("#idiom") == -1:
output.extend(text[pre_idiom_index + 13 :])
break
else:
idiom_index = idiom_index+1+text[idiom_index+1:].find("#idiom")
output.extend(text[pre_idiom_index+13:idiom_index])
output.append(text[idiom_index:idiom_index+13])
idiom_index = idiom_index + 1 + text[idiom_index + 1 :].find("#idiom")
output.extend(text[pre_idiom_index + 13 : idiom_index])
output.append(text[idiom_index : idiom_index + 13])
pre_idiom_index = idiom_index

return output
Expand All @@ -47,15 +45,15 @@ def add_tokens_around(tokens, idiom_index, tokens_num):
left_tokens_num = tokens_num // 2
right_tokens_num = tokens_num - left_tokens_num

if idiom_index >= left_tokens_num and (len(tokens)-1-idiom_index) >= right_tokens_num:
left_tokens = tokens[idiom_index-left_tokens_num: idiom_index]
right_tokens = tokens[idiom_index+1: idiom_index+1+right_tokens_num]
if idiom_index >= left_tokens_num and (len(tokens) - 1 - idiom_index) >= right_tokens_num:
left_tokens = tokens[idiom_index - left_tokens_num : idiom_index]
right_tokens = tokens[idiom_index + 1 : idiom_index + 1 + right_tokens_num]
elif idiom_index < left_tokens_num:
left_tokens = tokens[:idiom_index]
right_tokens = tokens[idiom_index+1: idiom_index+1+tokens_num-len(left_tokens)]
elif (len(tokens)-1-idiom_index) < right_tokens_num:
right_tokens = tokens[idiom_index+1:]
left_tokens = tokens[idiom_index-(tokens_num-len(right_tokens)): idiom_index]
right_tokens = tokens[idiom_index + 1 : idiom_index + 1 + tokens_num - len(left_tokens)]
elif (len(tokens) - 1 - idiom_index) < right_tokens_num:
right_tokens = tokens[idiom_index + 1 :]
left_tokens = tokens[idiom_index - (tokens_num - len(right_tokens)) : idiom_index]

return left_tokens, right_tokens

Expand All @@ -69,18 +67,18 @@ def read_dataset(args, data_path, answer_path):

for line in open(data_path, mode="r", encoding="utf-8"):
example = json.loads(line)
options = example['candidates']
for context in example['content']:
options = example["candidates"]
for context in example["content"]:
chid_tokens = tokenize_chid(context)
tags = [token for token in chid_tokens if '#idiom' in token]
tags = [token for token in chid_tokens if "#idiom" in token]
for tag in tags:
if answer_path is not None:
tgt = answers[tag]
else:
tgt = -1
tokens = []
for i, token in enumerate(chid_tokens):
if '#idiom' in token:
if "#idiom" in token:
sub_tokens = [str(token)]
else:
sub_tokens = args.tokenizer.tokenize(token)
Expand All @@ -90,17 +88,17 @@ def read_dataset(args, data_path, answer_path):
left_tokens, right_tokens = add_tokens_around(tokens, idiom_index, max_tokens_for_doc-1)

for i in range(len(left_tokens)):
if '#idiom' in left_tokens[i] and left_tokens[i] != tag:
if "#idiom" in left_tokens[i] and left_tokens[i] != tag:
left_tokens[i] = "[MASK]"
for i in range(len(right_tokens)):
if '#idiom' in right_tokens[i] and right_tokens[i] != tag:
if "#idiom" in right_tokens[i] and right_tokens[i] != tag:
right_tokens[i] = "[MASK]"

dataset.append(([], tgt, [], tag, group_index))

for option in options:
option_tokens = args.tokenizer.tokenize(option)
tokens = ['[CLS]'] + option_tokens + ['[SEP]'] + left_tokens + ['[unused1]'] + right_tokens + ['[SEP]']
tokens = ["[CLS]"] + option_tokens + ["[SEP]"] + left_tokens + ["[unused1]"] + right_tokens + ["[SEP]"]

src = args.tokenizer.convert_tokens_to_ids(tokens)[:args.seq_length]
seg = [0] * len(src)
Expand Down Expand Up @@ -136,8 +134,6 @@ def main():
args = parser.parse_args()

args.labels_num = args.max_choices_num
if args.output_model_path == None:
args.output_model_path = "./models/chid_model.bin"

# Load the hyperparameters from the config file.
args = load_hyperparam(args)
Expand Down Expand Up @@ -186,7 +182,7 @@ def main():
model = torch.nn.DataParallel(model)
args.model = model

total_loss, result, best_result = 0., 0., 0.
total_loss, result, best_result = 0.0, 0.0, 0.0

print("Start training.")

Expand Down
39 changes: 17 additions & 22 deletions run_classifier.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
"""
This script provides an exmaple to wrap UER-py for classification.
"""
import torch
import random
import argparse
import collections
import torch
import torch.nn as nn
from uer.layers import *
from uer.encoders import *
Expand Down Expand Up @@ -103,20 +102,20 @@ def build_optimizer(args, model):
def batch_loader(batch_size, src, tgt, seg, soft_tgt=None):
instances_num = src.size()[0]
for i in range(instances_num // batch_size):
src_batch = src[i*batch_size: (i+1)*batch_size, :]
tgt_batch = tgt[i*batch_size: (i+1)*batch_size]
seg_batch = seg[i*batch_size: (i+1)*batch_size, :]
src_batch = src[i * batch_size : (i + 1) * batch_size, :]
tgt_batch = tgt[i * batch_size : (i + 1) * batch_size]
seg_batch = seg[i * batch_size : (i + 1) * batch_size, :]
if soft_tgt is not None:
soft_tgt_batch = soft_tgt[i*batch_size: (i+1)*batch_size, :]
soft_tgt_batch = soft_tgt[i * batch_size : (i + 1) * batch_size, :]
yield src_batch, tgt_batch, seg_batch, soft_tgt_batch
else:
yield src_batch, tgt_batch, seg_batch, None
if instances_num > instances_num // batch_size * batch_size:
src_batch = src[instances_num//batch_size*batch_size:, :]
tgt_batch = tgt[instances_num//batch_size*batch_size:]
seg_batch = seg[instances_num//batch_size*batch_size:, :]
src_batch = src[instances_num // batch_size * batch_size :, :]
tgt_batch = tgt[instances_num // batch_size * batch_size :]
seg_batch = seg[instances_num // batch_size * batch_size :, :]
if soft_tgt is not None:
soft_tgt_batch = soft_tgt[instances_num//batch_size*batch_size:, :]
soft_tgt_batch = soft_tgt[instances_num // batch_size * batch_size :, :]
yield src_batch, tgt_batch, seg_batch, soft_tgt_batch
else:
yield src_batch, tgt_batch, seg_batch, None
Expand Down Expand Up @@ -146,8 +145,8 @@ def read_dataset(args, path):
seg = [1] * len(src_a) + [2] * len(src_b)

if len(src) > args.seq_length:
src = src[:args.seq_length]
seg = seg[:args.seq_length]
src = src[: args.seq_length]
seg = seg[: args.seq_length]
while len(src) < args.seq_length:
src.append(0)
seg.append(0)
Expand Down Expand Up @@ -190,7 +189,6 @@ def evaluate(args, dataset, print_confusion_matrix=False):
seg = torch.LongTensor([sample[2] for sample in dataset])

batch_size = args.batch_size
instances_num = src.size()[0]

correct = 0
# Confusion matrix.
Expand All @@ -203,7 +201,7 @@ def evaluate(args, dataset, print_confusion_matrix=False):
tgt_batch = tgt_batch.to(args.device)
seg_batch = seg_batch.to(args.device)
with torch.no_grad():
loss, logits = args.model(src_batch, tgt_batch, seg_batch)
_, logits = args.model(src_batch, tgt_batch, seg_batch)
pred = torch.argmax(nn.Softmax(dim=1)(logits), dim=1)
gold = tgt_batch
for j in range(pred.size()[0]):
Expand All @@ -215,13 +213,13 @@ def evaluate(args, dataset, print_confusion_matrix=False):
print(confusion)
print("Report precision, recall, and f1:")
for i in range(confusion.size()[0]):
p = confusion[i,i].item()/confusion[i,:].sum().item()
r = confusion[i,i].item()/confusion[:,i].sum().item()
f1 = 2*p*r / (p+r)
p = confusion[i,i].item() / confusion[i, :].sum().item()
r = confusion[i,i].item() / confusion[:, i].sum().item()
f1 = 2 * p * r / (p + r)
print("Label {}: {:.3f}, {:.3f}, {:.3f}".format(i,p,r,f1))

print("Acc. (Correct/Total): {:.4f} ({}/{}) ".format(correct/len(dataset), correct, len(dataset)))
return correct/len(dataset), confusion
return correct / len(dataset), confusion


def main():
Expand All @@ -246,9 +244,6 @@ def main():

args = parser.parse_args()

if args.output_model_path == None:
args.output_model_path = "./models/classifier_model.bin"

# Load the hyperparameters from the config file.
args = load_hyperparam(args)

Expand Down Expand Up @@ -303,7 +298,7 @@ def main():
model = torch.nn.DataParallel(model)
args.model = model

total_loss, result, best_result = 0., 0., 0.
total_loss, result, best_result = 0.0, 0.0, 0.0

print("Start training.")

Expand Down
Loading

0 comments on commit ad105ec

Please sign in to comment.