
Commit

pure training
lipiji committed Aug 21, 2020
1 parent f5305c5 commit 43f2fd3
Showing 5 changed files with 74 additions and 72 deletions.
9 changes: 2 additions & 7 deletions biglm.py
@@ -7,7 +7,7 @@
from label_smoothing import LabelSmoothing

class BIGLM(nn.Module):
-    def __init__(self, local_rank, vocab, embed_dim, ff_embed_dim, num_heads, dropout, layers, smoothing_factor, approx):
+    def __init__(self, local_rank, vocab, embed_dim, ff_embed_dim, num_heads, dropout, layers, smoothing_factor, approx=None):
super(BIGLM, self).__init__()
self.vocab = vocab
self.embed_dim = embed_dim
@@ -29,12 +29,7 @@ def __init__(self, local_rank, vocab, embed_dim, ff_embed_dim, num_heads, dropou
self.dropout = dropout
self.device = local_rank

-        if approx == "none":
-            self.approx = None
-        elif approx == "adaptive":
-            self.approx = nn.AdaptiveLogSoftmaxWithLoss(self.embed_dim, self.vocab.size, [10000, 20000, 200000])
-        else:
-            raise NotImplementedError("%s has not been implemented"%approx)
+        self.approx = approx
self.reset_parameters()

def reset_parameters(self):
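Note: with the string-based selection gone, BIGLM simply stores whatever approx module it is given (or None). A caller that still wants the adaptive softmax can build it outside the model and pass it in. A minimal standalone sketch using nn.AdaptiveLogSoftmaxWithLoss with the same cutoffs as the removed branch; the sizes and the surrounding BIGLM wiring here are illustrative, not taken from this repository:

import torch
import torch.nn as nn

# Illustrative sizes only; the real embed_dim and vocab size come from the config and vocab file.
embed_dim, vocab_size = 768, 250000

# Same cutoff buckets that the removed "adaptive" branch used.
approx = nn.AdaptiveLogSoftmaxWithLoss(embed_dim, vocab_size, cutoffs=[10000, 20000, 200000])

hidden = torch.randn(32, embed_dim)            # final hidden states, one row per target token
targets = torch.randint(0, vocab_size, (32,))  # gold token ids

out = approx(hidden, targets)                  # namedtuple (output, loss); loss is the mean NLL
full_log_probs = approx.log_prob(hidden)       # (32, vocab_size) log-probabilities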
2 changes: 2 additions & 0 deletions data.py
@@ -72,6 +72,8 @@ def s2xy(lines, vocab, max_len, min_len):
data = []
for line in lines:
res = parse_line(line, max_len, min_len)
+        if not res:
+            continue
data.append(res)
return batchify(data, vocab)

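Note: the new guard drops lines that parse_line rejects instead of appending a falsy value and letting batchify fail later. parse_line itself is not shown in this diff, so the sketch below only illustrates the filtering pattern with a hypothetical stand-in that rejects out-of-range lengths:

def parse_line_stub(line, max_len, min_len):
    # Hypothetical stand-in for data.parse_line: returns None for unusable lines.
    tokens = line.strip().split()
    if len(tokens) < min_len or len(tokens) > max_len:
        return None
    return tokens

def collect(lines, max_len=50, min_len=2):
    data = []
    for line in lines:
        res = parse_line_stub(line, max_len, min_len)
        if not res:      # same guard this commit adds to s2xy
            continue
        data.append(res)
    return data

print(collect(["one two three", "x", ""]))  # -> [['one', 'two', 'three']]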
95 changes: 51 additions & 44 deletions train.py
@@ -6,8 +6,7 @@
import torch.multiprocessing as mp

from biglm import BIGLM
-from data import Vocab, DataLoader
-from adam import AdamWeightDecayOptimizer
+from data import Vocab, DataLoader, s2xy
from optim import Optim

import argparse, os
@@ -22,6 +21,7 @@ def parse_config():
parser.add_argument('--dropout', type=float)

parser.add_argument('--train_data', type=str)
+    parser.add_argument('--dev_data', type=str)
parser.add_argument('--vocab', type=str)
parser.add_argument('--min_occur_cnt', type=int)
parser.add_argument('--batch_size', type=int)
@@ -36,8 +36,6 @@ def parse_config():
parser.add_argument('--start_from', type=str, default=None)
parser.add_argument('--save_dir', type=str)

-    parser.add_argument('--approx', type=str, default='none')
-    parser.add_argument('--fp16', action='store_true')
parser.add_argument('--world_size', type=int)
parser.add_argument('--gpus', type=int)
parser.add_argument('--MASTER_ADDR', type=str)
@@ -64,53 +62,59 @@ def average_gradients(model):
break
return normal

+def eval_epoch(lm_args, model, lm_vocab, local_rank, label):
+    print("validating...", flush=True)
+    ds = []
+    with open(lm_args.dev_data, "r") as f:
+        for line in f:
+            line = line.strip()
+            if line:
+                ds.append(line)
+
+    batch_size = 10
+    batches = round(len(ds) / batch_size)
+    idx = 0
+    avg_nll = 0.
+    avg_ppl = 0.
+    count = 0.
+    while idx < len(ds):
+        cplb = ds[idx:idx + batch_size]
+        xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk = s2xy(cplb, lm_vocab, lm_args.max_len, lm_args.min_len)
+
+        xs_tpl = xs_tpl.cuda(local_rank)
+        xs_seg = xs_seg.cuda(local_rank)
+        xs_pos = xs_pos.cuda(local_rank)
+        ys_truth = ys_truth.cuda(local_rank)
+        ys_inp = ys_inp.cuda(local_rank)
+        ys_tpl = ys_tpl.cuda(local_rank)
+        ys_seg = ys_seg.cuda(local_rank)
+        ys_pos = ys_pos.cuda(local_rank)
+        msk = msk.cuda(local_rank)
+
+        nll, ppl, bsz = model.ppl(xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk)
+
+        avg_nll += nll
+        avg_ppl += ppl
+        count += bsz
+
+        idx += batch_size
+
+    print(label, "nll=", avg_nll/count, "ppl=", avg_ppl/count, "count=", count, flush=True)
+
def run(args, local_rank):
""" Distributed Synchronous """
torch.manual_seed(1234)
vocab = Vocab(args.vocab, min_occur_cnt=args.min_occur_cnt, specials=[])
if (args.world_size == 1 or dist.get_rank() == 0):
-        print (vocab.size, flush=True)
+        print ("vocab.size = " + str(vocab.size), flush=True)
model = BIGLM(local_rank, vocab, args.embed_dim, args.ff_embed_dim,\
-                  args.num_heads, args.dropout, args.layers, args.smoothing, args.approx)
+                  args.num_heads, args.dropout, args.layers, args.smoothing)
if args.start_from is not None:
ckpt = torch.load(args.start_from, map_location='cpu')
model.load_state_dict(ckpt['model'])
model = model.cuda(local_rank)

-    weight_decay_params = []
-    no_weight_decay_params = []
-
-    for name, param in model.named_parameters():
-        if name.endswith('bias') or 'layer_norm' in name:
-            no_weight_decay_params.append(param)
-        else:
-            weight_decay_params.append(param)
-    grouped_params = [{'params':weight_decay_params, 'weight_decay':args.weight_decay},
-                      {'params':no_weight_decay_params, 'weight_decay':0.}]
-    if args.world_size > 1:
-        torch.manual_seed(1234 + dist.get_rank())
-        random.seed(5678 + dist.get_rank())
-
-    if args.fp16:
-        try:
-            from apex.optimizers import FP16_Optimizer
-            from apex.optimizers import FusedAdam
-        except ImportError:
-            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
-        optimizer = FusedAdam(grouped_params,
-                              lr=args.lr,
-                              betas=(0.9, 0.999),
-                              eps=1e-6,
-                              bias_correction=False,
-                              max_grad_norm=1.0)
-        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
-
-    else:
-        if args.weight_decay > 0:
-            optimizer = AdamWeightDecayOptimizer(grouped_params,
-                                                 lr=args.lr, betas=(0.9, 0.999), eps=1e-6)
-        else:
-            optimizer = Optim(model.embed_dim, args.lr, args.warmup_steps, torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.998), eps=1e-9))
+    optimizer = Optim(model.embed_dim, args.lr, args.warmup_steps, torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.998), eps=1e-9))

if args.start_from is not None:
optimizer.load_state_dict(ckpt['optimizer'])
@@ -143,10 +147,8 @@ def run(args, local_rank):
ntokens_acm += ntokens
npairs_acm += npairs
nxs += npairs
-            if args.fp16:
-                optimizer.backward(loss)
-            else:
-                loss.backward()
+
+            loss.backward()
if args.world_size > 1:
is_normal = average_gradients(model)
else:
@@ -165,6 +167,11 @@
if (args.world_size==1 or dist.get_rank() ==0) and batch_acm%args.save_every == -1%args.save_every:
if not os.path.exists(args.save_dir):
os.mkdir(args.save_dir)

+                model.eval()
+                eval_epoch(args, model, vocab, local_rank, "epoch-" + str(train_data.epoch_id) + "-acm-" + str(batch_acm))
+                model.train()

torch.save({'args':args, 'model':model.state_dict(), 'optimizer':optimizer.state_dict()}, '%s/epoch%d_batch_%d'%(args.save_dir, train_data.epoch_id, batch_acm))

def init_processes(args, local_rank, fn, backend='nccl'):
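Note: the fp16/apex and AdamWeightDecayOptimizer paths are removed, so training always goes through the Optim wrapper around Adam with warmup, and validation perplexity is now reported at every checkpoint via eval_epoch. optim.py is not part of this diff; the sketch below shows the inverse-square-root ("Noam") schedule such a wrapper typically implements, stated as an assumption about its behavior rather than a copy of the repository's code:

import torch

def noam_lr(step, embed_dim, factor, warmup_steps):
    # Assumed schedule: scale by embed_dim**-0.5, linear warmup, then step**-0.5 decay.
    step = max(step, 1)
    return factor * (embed_dim ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)

model = torch.nn.Linear(768, 768)  # stand-in for BIGLM
inner = torch.optim.Adam(model.parameters(), lr=0., betas=(0.9, 0.998), eps=1e-9)

for step in range(1, 4):
    lr = noam_lr(step, embed_dim=768, factor=0.5, warmup_steps=8000)  # factor mirrors --lr 0.5
    for group in inner.param_groups:
        group['lr'] = lr
    loss = model(torch.randn(8, 768)).pow(2).mean()
    loss.backward()
    inner.step()
    inner.zero_grad()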
5 changes: 3 additions & 2 deletions train.sh
@@ -1,13 +1,14 @@
-CUDA_VISIBLE_DEVICES=0 \
+CUDA_VISIBLE_DEVICES=1 \
python3 -u train.py --embed_dim 768 \
--ff_embed_dim 3072 \
--num_heads 12 \
--layers 12 \
--dropout 0.2 \
--train_data ./data/train.txt \
+        --dev_data ./data/dev.txt \
--vocab ./data/vocab.txt \
--min_occur_cnt 1 \
-        --batch_size 2 \
+        --batch_size 32 \
--warmup_steps 8000 \
--lr 0.5 \
--weight_decay 0 \
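Note: besides switching GPUs and raising the batch size from 2 to 32, the script now passes the new --dev_data flag through to parse_config() in train.py. A trimmed-down parser (only the flags this commit touches; the real one defines many more options) shows how they land in args:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--train_data', type=str)
parser.add_argument('--dev_data', type=str)   # new in this commit, consumed by eval_epoch
parser.add_argument('--batch_size', type=int)

args = parser.parse_args([
    '--train_data', './data/train.txt',
    '--dev_data', './data/dev.txt',
    '--batch_size', '32',
])
print(args.dev_data, args.batch_size)  # ./data/dev.txt 32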
35 changes: 16 additions & 19 deletions utils.py
@@ -9,25 +9,22 @@ def gelu(x):
cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
return cdf*x

-try:
-    from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
-except ImportError:
-    class LayerNorm(nn.Module):
-        def __init__(self, hidden_size, eps=1e-12):
-            super(LayerNorm, self).__init__()
-            self.weight = nn.Parameter(torch.Tensor(hidden_size))
-            self.bias = nn.Parameter(torch.Tensor(hidden_size))
-            self.eps = eps
-            self.reset_parameters()
-        def reset_parameters(self):
-            nn.init.constant_(self.weight, 1.)
-            nn.init.constant_(self.bias, 0.)
-
-        def forward(self, x):
-            u = x.mean(-1, keepdim=True)
-            s = (x - u).pow(2).mean(-1, keepdim=True)
-            x = (x - u) / torch.sqrt(s + self.eps)
-            return self.weight * x + self.bias
+class LayerNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-12):
+        super(LayerNorm, self).__init__()
+        self.weight = nn.Parameter(torch.Tensor(hidden_size))
+        self.bias = nn.Parameter(torch.Tensor(hidden_size))
+        self.eps = eps
+        self.reset_parameters()
+    def reset_parameters(self):
+        nn.init.constant_(self.weight, 1.)
+        nn.init.constant_(self.bias, 0.)
+
+    def forward(self, x):
+        u = x.mean(-1, keepdim=True)
+        s = (x - u).pow(2).mean(-1, keepdim=True)
+        x = (x - u) / torch.sqrt(s + self.eps)
+        return self.weight * x + self.bias


INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)
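Note: with the apex FusedLayerNorm fallback removed, LayerNorm is always the plain PyTorch class above. Assuming utils.py is importable from the working directory, a quick check that it matches torch.nn.LayerNorm at the same eps:

import torch
from utils import LayerNorm  # the class defined above

x = torch.randn(4, 16, 768)
custom = LayerNorm(768, eps=1e-12)
builtin = torch.nn.LayerNorm(768, eps=1e-12)  # same init: weight=1, bias=0

print(torch.allclose(custom(x), builtin(x), atol=1e-6))  # expected: True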
